diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml
index 6c99b21b..7609a68b 100644
--- a/.github/workflows/docker-image.yml
+++ b/.github/workflows/docker-image.yml
@@ -1,13 +1,14 @@
name: docker-image
on:
+ workflow_dispatch:
push:
tags:
- v*
env:
APP_NAME: CLIProxyAPI
- DOCKERHUB_REPO: eceasy/cli-proxy-api
+ DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/cli-proxy-api-plus
jobs:
docker_amd64:
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 64e7a5b7..04ec21a9 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -23,7 +23,8 @@ jobs:
cache: true
- name: Generate Build Metadata
run: |
- echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
+ VERSION=$(git describe --tags --always --dirty)
+ echo "VERSION=${VERSION}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- uses: goreleaser/goreleaser-action@v4
diff --git a/.gitignore b/.gitignore
index 183138f9..feda9dbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,17 +1,20 @@
# Binaries
cli-proxy-api
+cliproxy
*.exe
+
# Configuration
config.yaml
.env
-
+.mcp.json
# Generated content
bin/*
logs/*
conv/*
temp/*
refs/*
+tmp/*
# Storage backends
pgstore/*
@@ -44,7 +47,9 @@ GEMINI.md
.bmad/*
_bmad/*
_bmad-output/*
+.mcp/cache/
# macOS
.DS_Store
._*
+*.bak
diff --git a/.goreleaser.yml b/.goreleaser.yml
index 31d05e6d..6e1829ed 100644
--- a/.goreleaser.yml
+++ b/.goreleaser.yml
@@ -1,5 +1,5 @@
builds:
- - id: "cli-proxy-api"
+ - id: "cli-proxy-api-plus"
env:
- CGO_ENABLED=0
goos:
@@ -10,11 +10,11 @@ builds:
- amd64
- arm64
main: ./cmd/server/
- binary: cli-proxy-api
+ binary: cli-proxy-api-plus
ldflags:
- - -s -w -X 'main.Version={{.Version}}' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}'
+ - -s -w -X 'main.Version={{.Version}}-plus' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}'
archives:
- - id: "cli-proxy-api"
+ - id: "cli-proxy-api-plus"
format: tar.gz
format_overrides:
- goos: windows
diff --git a/Dockerfile b/Dockerfile
index 3e10c4f9..cde6205a 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,7 +12,7 @@ ARG VERSION=dev
ARG COMMIT=none
ARG BUILD_DATE=unknown
-RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPI ./cmd/server/
+RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}-plus' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPIPlus ./cmd/server/
FROM alpine:3.22.0
@@ -20,7 +20,7 @@ RUN apk add --no-cache tzdata
RUN mkdir /CLIProxyAPI
-COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI
+COPY --from=builder ./app/CLIProxyAPIPlus /CLIProxyAPI/CLIProxyAPIPlus
COPY config.example.yaml /CLIProxyAPI/config.example.yaml
@@ -32,4 +32,4 @@ ENV TZ=Asia/Shanghai
RUN cp /usr/share/zoneinfo/${TZ} /etc/localtime && echo "${TZ}" > /etc/timezone
-CMD ["./CLIProxyAPI"]
\ No newline at end of file
+CMD ["./CLIProxyAPIPlus"]
\ No newline at end of file
diff --git a/README.md b/README.md
index 4fa495c6..2d950a4c 100644
--- a/README.md
+++ b/README.md
@@ -1,168 +1,99 @@
-# CLI Proxy API
+# CLIProxyAPI Plus
-English | [中文](README_CN.md)
+English | [Chinese](README_CN.md)
-A proxy server that provides OpenAI/Gemini/Claude/Codex compatible API interfaces for CLI.
+This is the Plus version of [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI), adding support for third-party providers on top of the mainline project.
-It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth.
+All third-party provider support is maintained by community contributors; CLIProxyAPI does not provide technical support. Please contact the corresponding community maintainer if you need assistance.
-So you can use local or multi-account CLI access with OpenAI(include Responses)/Gemini/Claude-compatible clients and SDKs.
+The Plus release stays in lockstep with the mainline features.
-## Sponsor
+## Differences from the Mainline
-[](https://z.ai/subscribe?ic=8JVLJQFSKB)
+- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
+- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
-This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN.
+## New Features (Plus Enhanced)
-GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.7 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
+- **OAuth Web Authentication**: Browser-based OAuth login for Kiro with beautiful web UI
+- **Rate Limiter**: Built-in request rate limiting to prevent API abuse
+- **Background Token Refresh**: Automatic token refresh 10 minutes before expiration
+- **Metrics & Monitoring**: Request metrics collection for monitoring and debugging
+- **Device Fingerprint**: Device fingerprint generation for enhanced security
+- **Cooldown Management**: Smart cooldown mechanism for API rate limits
+- **Usage Checker**: Real-time usage monitoring and quota management
+- **Model Converter**: Unified model name conversion across providers
+- **UTF-8 Stream Processing**: Improved streaming response handling
-Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB
+## Kiro Authentication
----
+### Web-based OAuth Login
-
-
-
- |
-Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using this link and enter the "cliproxyapi" promo code during recharge to get 10% off. |
-
-
- |
-Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via this link to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off! |
-
-
-
+Access the Kiro OAuth web interface at:
-## Overview
+```
+http://your-server:8080/v0/oauth/kiro
+```
-- OpenAI/Gemini/Claude compatible API endpoints for CLI models
-- OpenAI Codex support (GPT models) via OAuth login
-- Claude Code support via OAuth login
-- Qwen Code support via OAuth login
-- iFlow support via OAuth login
-- Amp CLI and IDE extensions support with provider routing
-- Streaming and non-streaming responses
-- Function calling/tools support
-- Multimodal input support (text and images)
-- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude, Qwen and iFlow)
-- Simple CLI authentication flows (Gemini, OpenAI, Claude, Qwen and iFlow)
-- Generative Language API Key support
-- AI Studio Build multi-account load balancing
-- Gemini CLI multi-account load balancing
-- Claude Code multi-account load balancing
-- Qwen Code multi-account load balancing
-- iFlow multi-account load balancing
-- OpenAI Codex multi-account load balancing
-- OpenAI-compatible upstream providers via config (e.g., OpenRouter)
-- Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`)
+This provides a browser-based OAuth flow for Kiro (AWS CodeWhisperer) authentication with:
+- AWS Builder ID login
+- AWS Identity Center (IDC) login
+- Token import from Kiro IDE
-## Getting Started
+## Quick Deployment with Docker
-CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/)
+### One-Command Deployment
-## Management API
+```bash
+# Create deployment directory
+mkdir -p ~/cli-proxy && cd ~/cli-proxy
-see [MANAGEMENT_API.md](https://help.router-for.me/management/api)
+# Create docker-compose.yml
+cat > docker-compose.yml << 'EOF'
+services:
+ cli-proxy-api:
+ image: eceasy/cli-proxy-api-plus:latest
+ container_name: cli-proxy-api-plus
+ ports:
+ - "8317:8317"
+ volumes:
+ - ./config.yaml:/CLIProxyAPI/config.yaml
+ - ./auths:/root/.cli-proxy-api
+ - ./logs:/CLIProxyAPI/logs
+ restart: unless-stopped
+EOF
-## Amp CLI Support
+# Download example config
+curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
-CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and Amp IDE extensions, enabling you to use your Google/ChatGPT/Claude OAuth subscriptions with Amp's coding tools:
+# Pull and start
+docker compose pull && docker compose up -d
+```
-- Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`)
-- Management proxy for OAuth authentication and account features
-- Smart model fallback with automatic routing
-- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
-- Security-first design with localhost-only management endpoints
+### Configuration
-**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/agent-client/amp-cli.html)**
+Edit `config.yaml` before starting:
-## SDK Docs
+```yaml
+# Basic configuration example
+server:
+ port: 8317
-- Usage: [docs/sdk-usage.md](docs/sdk-usage.md)
-- Advanced (executors & translators): [docs/sdk-advanced.md](docs/sdk-advanced.md)
-- Access: [docs/sdk-access.md](docs/sdk-access.md)
-- Watcher: [docs/sdk-watcher.md](docs/sdk-watcher.md)
-- Custom Provider Example: `examples/custom-provider`
+# Add your provider configurations here
+```
+
+### Update to Latest Version
+
+```bash
+cd ~/cli-proxy
+docker compose pull && docker compose up -d
+```
## Contributing
-Contributions are welcome! Please feel free to submit a Pull Request.
+This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
-1. Fork the repository
-2. Create your feature branch (`git checkout -b feature/amazing-feature`)
-3. Commit your changes (`git commit -m 'Add some amazing feature'`)
-4. Push to the branch (`git push origin feature/amazing-feature`)
-5. Open a Pull Request
-
-## Who is with us?
-
-Those projects are based on CLIProxyAPI:
-
-### [vibeproxy](https://github.com/automazeio/vibeproxy)
-
-Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with AI coding tools - no API keys needed
-
-### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
-
-Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed
-
-### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
-
-CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed
-
-### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
-
-Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed.
-
-### [Quotio](https://github.com/nguyenphutrong/quotio)
-
-Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
-
-### [CodMate](https://github.com/loocor/CodMate)
-
-Native macOS SwiftUI app for managing CLI AI sessions (Codex, Claude Code, Gemini CLI) with unified provider management, Git review, project organization, global search, and terminal integration. Integrates CLIProxyAPI to provide OAuth authentication for Codex, Claude, Gemini, Antigravity, and Qwen Code, with built-in and third-party provider rerouting through a single proxy endpoint - no API keys needed for OAuth providers.
-
-### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
-
-Windows-native CLIProxyAPI fork with TUI, system tray, and multi-provider OAuth for AI coding tools - no API keys needed.
-
-### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
-
-VSCode extension for quick switching between Claude Code models, featuring integrated CLIProxyAPI as its backend with automatic background lifecycle management.
-
-### [ZeroLimit](https://github.com/0xtbug/zero-limit)
-
-Windows desktop app built with Tauri + React for monitoring AI coding assistant quotas via CLIProxyAPI. Track usage across Gemini, Claude, OpenAI Codex, and Antigravity accounts with real-time dashboard, system tray integration, and one-click proxy control - no API keys needed.
-
-### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
-
-A lightweight web admin panel for CLIProxyAPI with health checks, resource monitoring, real-time logs, auto-update, request statistics and pricing display. Supports one-click installation and systemd service.
-
-### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
-
-A Windows tray application implemented using PowerShell scripts, without relying on any third-party libraries. The main features include: automatic creation of shortcuts, silent running, password management, channel switching (Main / Plus), and automatic downloading and updating.
-
-### [霖君](https://github.com/wangdabaoqq/LinJun)
-
-霖君 is a cross-platform desktop application for managing AI programming assistants, supporting macOS, Windows, and Linux systems. Unified management of Claude Code, Gemini CLI, OpenAI Codex, Qwen Code, and other AI coding tools, with local proxy for multi-account quota tracking and one-click configuration.
-
-### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
-
-A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed.
-
-> [!NOTE]
-> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
-
-## More choices
-
-Those projects are ports of CLIProxyAPI or inspired by it:
-
-### [9Router](https://github.com/decolua/9router)
-
-A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
-
-> [!NOTE]
-> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
+If you need to submit any non-third-party provider changes, please open them against the [mainline](https://github.com/router-for-me/CLIProxyAPI) repository.
## License
diff --git a/README_CN.md b/README_CN.md
index 5c91cbdc..79b5203f 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -1,176 +1,100 @@
-# CLI 代理 API
+# CLIProxyAPI Plus
[English](README.md) | 中文
-一个为 CLI 提供 OpenAI/Gemini/Claude/Codex 兼容 API 接口的代理服务器。
+这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。
-现已支持通过 OAuth 登录接入 OpenAI Codex(GPT 系列)和 Claude Code。
+所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
-您可以使用本地或多账户的CLI方式,通过任何与 OpenAI(包括Responses)/Gemini/Claude 兼容的客户端和SDK进行访问。
+该 Plus 版本的主线功能与主线功能强制同步。
-## 赞助商
+## 与主线版本版本差异
-[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
+- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
+- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
-本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。
+## 新增功能 (Plus 增强版)
-GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7,为开发者提供顶尖的编码体验。
+- **OAuth Web 认证**: 基于浏览器的 Kiro OAuth 登录,提供美观的 Web UI
+- **请求限流器**: 内置请求限流,防止 API 滥用
+- **后台令牌刷新**: 过期前 10 分钟自动刷新令牌
+- **监控指标**: 请求指标收集,用于监控和调试
+- **设备指纹**: 设备指纹生成,增强安全性
+- **冷却管理**: 智能冷却机制,应对 API 速率限制
+- **用量检查器**: 实时用量监控和配额管理
+- **模型转换器**: 跨供应商的统一模型名称转换
+- **UTF-8 流处理**: 改进的流式响应处理
-智谱AI为本软件提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
+## Kiro 认证
----
+### 网页端 OAuth 登录
-
-
-
- |
-感谢 PackyCode 对本项目的赞助!PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用此链接注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。 |
-
-
- |
-感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过此链接注册的用户,可享受首充8折,企业客户最高可享 7.5 折! |
-
-
-
+访问 Kiro OAuth 网页认证界面:
+```
+http://your-server:8080/v0/oauth/kiro
+```
-## 功能特性
+提供基于浏览器的 Kiro (AWS CodeWhisperer) OAuth 认证流程,支持:
+- AWS Builder ID 登录
+- AWS Identity Center (IDC) 登录
+- 从 Kiro IDE 导入令牌
-- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点
-- 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录)
-- 新增 Claude Code 支持(OAuth 登录)
-- 新增 Qwen Code 支持(OAuth 登录)
-- 新增 iFlow 支持(OAuth 登录)
-- 支持流式与非流式响应
-- 函数调用/工具支持
-- 多模态输入(文本、图片)
-- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude、Qwen 与 iFlow)
-- 简单的 CLI 身份验证流程(Gemini、OpenAI、Claude、Qwen 与 iFlow)
-- 支持 Gemini AIStudio API 密钥
-- 支持 AI Studio Build 多账户轮询
-- 支持 Gemini CLI 多账户轮询
-- 支持 Claude Code 多账户轮询
-- 支持 Qwen Code 多账户轮询
-- 支持 iFlow 多账户轮询
-- 支持 OpenAI Codex 多账户轮询
-- 通过配置接入上游 OpenAI 兼容提供商(例如 OpenRouter)
-- 可复用的 Go SDK(见 `docs/sdk-usage_CN.md`)
+## Docker 快速部署
-## 新手入门
+### 一键部署
-CLIProxyAPI 用户手册: [https://help.router-for.me/](https://help.router-for.me/cn/)
+```bash
+# 创建部署目录
+mkdir -p ~/cli-proxy && cd ~/cli-proxy
-## 管理 API 文档
+# 创建 docker-compose.yml
+cat > docker-compose.yml << 'EOF'
+services:
+ cli-proxy-api:
+ image: eceasy/cli-proxy-api-plus:latest
+ container_name: cli-proxy-api-plus
+ ports:
+ - "8317:8317"
+ volumes:
+ - ./config.yaml:/CLIProxyAPI/config.yaml
+ - ./auths:/root/.cli-proxy-api
+ - ./logs:/CLIProxyAPI/logs
+ restart: unless-stopped
+EOF
-请参见 [MANAGEMENT_API_CN.md](https://help.router-for.me/cn/management/api)
+# 下载示例配置
+curl -o config.yaml https://raw.githubusercontent.com/router-for-me/CLIProxyAPIPlus/main/config.example.yaml
-## Amp CLI 支持
+# 拉取并启动
+docker compose pull && docker compose up -d
+```
-CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支持,可让你使用自己的 Google/ChatGPT/Claude OAuth 订阅来配合 Amp 编码工具:
+### 配置说明
-- 提供商路由别名,兼容 Amp 的 API 路径模式(`/api/provider/{provider}/v1...`)
-- 管理代理,处理 OAuth 认证和账号功能
-- 智能模型回退与自动路由
-- 以安全为先的设计,管理端点仅限 localhost
+启动前请编辑 `config.yaml`:
-**→ [Amp CLI 完整集成指南](https://help.router-for.me/cn/agent-client/amp-cli.html)**
+```yaml
+# 基本配置示例
+server:
+ port: 8317
-## SDK 文档
+# 在此添加你的供应商配置
+```
-- 使用文档:[docs/sdk-usage_CN.md](docs/sdk-usage_CN.md)
-- 高级(执行器与翻译器):[docs/sdk-advanced_CN.md](docs/sdk-advanced_CN.md)
-- 认证: [docs/sdk-access_CN.md](docs/sdk-access_CN.md)
-- 凭据加载/更新: [docs/sdk-watcher_CN.md](docs/sdk-watcher_CN.md)
-- 自定义 Provider 示例:`examples/custom-provider`
+### 更新到最新版本
+
+```bash
+cd ~/cli-proxy
+docker compose pull && docker compose up -d
+```
## 贡献
-欢迎贡献!请随时提交 Pull Request。
+该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
-1. Fork 仓库
-2. 创建您的功能分支(`git checkout -b feature/amazing-feature`)
-3. 提交您的更改(`git commit -m 'Add some amazing feature'`)
-4. 推送到分支(`git push origin feature/amazing-feature`)
-5. 打开 Pull Request
-
-## 谁与我们在一起?
-
-这些项目基于 CLIProxyAPI:
-
-### [vibeproxy](https://github.com/automazeio/vibeproxy)
-
-一个原生 macOS 菜单栏应用,让您可以使用 Claude Code & ChatGPT 订阅服务和 AI 编程工具,无需 API 密钥。
-
-### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
-
-一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。
-
-### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
-
-CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户和替代模型(Gemini, Codex, Antigravity),无需 API 密钥。
-
-### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
-
-基于 macOS 平台的原生 CLIProxyAPI GUI:配置供应商、模型映射以及OAuth端点,无需 API 密钥。
-
-### [Quotio](https://github.com/nguyenphutrong/quotio)
-
-原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
-
-### [CodMate](https://github.com/loocor/CodMate)
-
-原生 macOS SwiftUI 应用,用于管理 CLI AI 会话(Claude Code、Codex、Gemini CLI),提供统一的提供商管理、Git 审查、项目组织、全局搜索和终端集成。集成 CLIProxyAPI 为 Codex、Claude、Gemini、Antigravity 和 Qwen Code 提供统一的 OAuth 认证,支持内置和第三方提供商通过单一代理端点重路由 - OAuth 提供商无需 API 密钥。
-
-### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
-
-原生 Windows CLIProxyAPI 分支,集成 TUI、系统托盘及多服务商 OAuth 认证,专为 AI 编程工具打造,无需 API 密钥。
-
-### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
-
-一款 VSCode 扩展,提供了在 VSCode 中快速切换 Claude Code 模型的功能,内置 CLIProxyAPI 作为其后端,支持后台自动启动和关闭。
-
-### [ZeroLimit](https://github.com/0xtbug/zero-limit)
-
-Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI 监控 AI 编程助手配额。支持跨 Gemini、Claude、OpenAI Codex 和 Antigravity 账户的使用量追踪,提供实时仪表盘、系统托盘集成和一键代理控制,无需 API 密钥。
-
-### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
-
-面向 CLIProxyAPI 的 Web 管理面板,提供健康检查、资源监控、日志查看、自动更新、请求统计与定价展示,支持一键安装与 systemd 服务。
-
-### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
-
-Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方库。主要功能包括:自动创建快捷方式、静默运行、密码管理、通道切换(Main / Plus)以及自动下载与更新。
-
-### [霖君](https://github.com/wangdabaoqq/LinJun)
-
-霖君是一款用于管理AI编程助手的跨平台桌面应用,支持macOS、Windows、Linux系统。统一管理Claude Code、Gemini CLI、OpenAI Codex、Qwen Code等AI编程工具,本地代理实现多账户配额跟踪和一键配置。
-
-### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
-
-一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。
-
-> [!NOTE]
-> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
-
-## 更多选择
-
-以下项目是 CLIProxyAPI 的移植版或受其启发:
-
-### [9Router](https://github.com/decolua/9router)
-
-基于 Next.js 的实现,灵感来自 CLIProxyAPI,易于安装使用;自研格式转换(OpenAI/Claude/Gemini/Ollama)、组合系统与自动回退、多账户管理(指数退避)、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
-
-> [!NOTE]
-> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
+如果需要提交任何非第三方供应商支持的 Pull Request,请提交到[主线](https://github.com/router-for-me/CLIProxyAPI)版本。
## 许可证
-此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
-
-## 写给所有中国网友的
-
-QQ 群:188637136
-
-或
-
-Telegram 群:https://t.me/CLIProxyAPI
+此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
\ No newline at end of file
diff --git a/cmd/server/main.go b/cmd/server/main.go
index 684d9295..2ef8c339 100644
--- a/cmd/server/main.go
+++ b/cmd/server/main.go
@@ -18,6 +18,7 @@ import (
"github.com/joho/godotenv"
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -49,6 +50,19 @@ func init() {
buildinfo.BuildDate = BuildDate
}
+// setKiroIncognitoMode sets the incognito browser mode for Kiro authentication.
+// Kiro defaults to incognito mode for multi-account support.
+// Users can explicitly override with --incognito or --no-incognito flags.
+func setKiroIncognitoMode(cfg *config.Config, useIncognito, noIncognito bool) {
+ if useIncognito {
+ cfg.IncognitoBrowser = true
+ } else if noIncognito {
+ cfg.IncognitoBrowser = false
+ } else {
+ cfg.IncognitoBrowser = true // Kiro default
+ }
+}
+
// main is the entry point of the application.
// It parses command-line flags, loads configuration, and starts the appropriate
// service based on the provided flags (login, codex-login, or server mode).
@@ -60,30 +74,48 @@ func main() {
var codexLogin bool
var claudeLogin bool
var qwenLogin bool
+ var kiloLogin bool
var iflowLogin bool
var iflowCookie bool
var noBrowser bool
var oauthCallbackPort int
var antigravityLogin bool
var kimiLogin bool
+ var kiroLogin bool
+ var kiroGoogleLogin bool
+ var kiroAWSLogin bool
+ var kiroAWSAuthCode bool
+ var kiroImport bool
+ var githubCopilotLogin bool
var projectID string
var vertexImport string
var configPath string
var password string
var tuiMode bool
var standalone bool
+ var noIncognito bool
+ var useIncognito bool
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
+ flag.BoolVar(&kiloLogin, "kilo-login", false, "Login to Kilo AI using device flow")
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
+ flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
+ flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
+ flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
+ flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
+ flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
+ flag.BoolVar(&kiroAWSAuthCode, "kiro-aws-authcode", false, "Login to Kiro using AWS Builder ID (authorization code flow, better UX)")
+ flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
+ flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
@@ -464,6 +496,9 @@ func main() {
} else if antigravityLogin {
// Handle Antigravity login
cmd.DoAntigravityLogin(cfg, options)
+ } else if githubCopilotLogin {
+ // Handle GitHub Copilot login
+ cmd.DoGitHubCopilotLogin(cfg, options)
} else if codexLogin {
// Handle Codex login
cmd.DoCodexLogin(cfg, options)
@@ -472,12 +507,38 @@ func main() {
cmd.DoClaudeLogin(cfg, options)
} else if qwenLogin {
cmd.DoQwenLogin(cfg, options)
+ } else if kiloLogin {
+ cmd.DoKiloLogin(cfg, options)
} else if iflowLogin {
cmd.DoIFlowLogin(cfg, options)
} else if iflowCookie {
cmd.DoIFlowCookieAuth(cfg, options)
} else if kimiLogin {
cmd.DoKimiLogin(cfg, options)
+ } else if kiroLogin {
+ // For Kiro auth, default to incognito mode for multi-account support
+ // Users can explicitly override with --no-incognito
+ // Note: This config mutation is safe - auth commands exit after completion
+ // and don't share config with StartService (which is in the else branch)
+ setKiroIncognitoMode(cfg, useIncognito, noIncognito)
+ cmd.DoKiroLogin(cfg, options)
+ } else if kiroGoogleLogin {
+ // For Kiro auth, default to incognito mode for multi-account support
+ // Users can explicitly override with --no-incognito
+ // Note: This config mutation is safe - auth commands exit after completion
+ setKiroIncognitoMode(cfg, useIncognito, noIncognito)
+ cmd.DoKiroGoogleLogin(cfg, options)
+ } else if kiroAWSLogin {
+ // For Kiro auth, default to incognito mode for multi-account support
+ // Users can explicitly override with --no-incognito
+ setKiroIncognitoMode(cfg, useIncognito, noIncognito)
+ cmd.DoKiroAWSLogin(cfg, options)
+ } else if kiroAWSAuthCode {
+ // For Kiro auth with authorization code flow (better UX)
+ setKiroIncognitoMode(cfg, useIncognito, noIncognito)
+ cmd.DoKiroAWSAuthCodeLogin(cfg, options)
+ } else if kiroImport {
+ cmd.DoKiroImport(cfg, options)
} else {
// In cloud deploy mode without config file, just wait for shutdown signals
if isCloudDeploy && !configFileExists {
@@ -559,9 +620,15 @@ func main() {
}
}
} else {
- // Start the main proxy service
- managementasset.StartAutoUpdater(context.Background(), configFilePath)
- cmd.StartService(cfg, configFilePath, password)
+ // Start the main proxy service
+ managementasset.StartAutoUpdater(context.Background(), configFilePath)
+
+ if cfg.AuthDir != "" {
+ kiro.InitializeAndStart(cfg.AuthDir, cfg)
+ defer kiro.StopGlobalRefreshManager()
+ }
+
+ cmd.StartService(cfg, configFilePath, password)
}
}
}
diff --git a/config.example.yaml b/config.example.yaml
index 92619493..d86a8aef 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -1,6 +1,6 @@
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
-host: ""
+host: ''
# Server port
port: 8317
@@ -8,8 +8,8 @@ port: 8317
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
tls:
enable: false
- cert: ""
- key: ""
+ cert: ''
+ key: ''
# Management API settings
remote-management:
@@ -20,22 +20,22 @@ remote-management:
# Management key. If a plaintext value is provided here, it will be hashed on startup.
# All management requests (even from localhost) require this key.
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
- secret-key: ""
+ secret-key: ''
# Disable the bundled management control panel asset download and HTTP route when true.
disable-control-panel: false
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
- panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
+ panel-github-repository: 'https://github.com/router-for-me/Cli-Proxy-API-Management-Center'
# Authentication directory (supports ~ for home directory)
-auth-dir: "~/.cli-proxy-api"
+auth-dir: '~/.cli-proxy-api'
# API keys for authentication
api-keys:
- - "your-api-key-1"
- - "your-api-key-2"
- - "your-api-key-3"
+ - 'your-api-key-1'
+ - 'your-api-key-2'
+ - 'your-api-key-3'
# Enable debug logging
debug: false
@@ -43,11 +43,16 @@ debug: false
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
pprof:
enable: false
- addr: "127.0.0.1:8316"
+ addr: '127.0.0.1:8316'
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
+# Open OAuth URLs in incognito/private browser mode.
+# Useful when you want to login with a different account without logging out from your current session.
+# Default: false (but Kiro auth defaults to true for multi-account support)
+incognito-browser: true
+
# When true, write application logs to rotating files instead of stdout
logging-to-file: false
@@ -63,7 +68,7 @@ error-logs-max-files: 10
usage-statistics-enabled: false
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
-proxy-url: ""
+proxy-url: ''
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
force-model-prefix: false
@@ -81,7 +86,7 @@ quota-exceeded:
# Routing strategy for selecting credentials when multiple match.
routing:
- strategy: "round-robin" # round-robin (default), fill-first
+ strategy: 'round-robin' # round-robin (default), fill-first
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
@@ -164,6 +169,31 @@ nonstream-keepalive-interval: 0
# runtime-version: "v24.3.0"
# timeout: "600"
+# Kiro (AWS CodeWhisperer) configuration
+# Note: Kiro API currently only operates in us-east-1 region
+#kiro:
+# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
+# agent-task-type: "" # optional: "vibe" or empty (API default)
+# - access-token: "aoaAAAAA..." # or provide tokens directly
+# refresh-token: "aorAAAAA..."
+# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
+# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
+
+# Kilocode (OAuth-based code assistant)
+# Note: Kilocode uses OAuth device flow authentication.
+# Use the CLI command: ./server --kilo-login
+# This will save credentials to the auth directory (default: ~/.cli-proxy-api/)
+# oauth-model-alias:
+# kilo:
+# - name: "minimax/minimax-m2.5:free"
+# alias: "minimax-m2.5"
+# - name: "z-ai/glm-5:free"
+# alias: "glm-5"
+# oauth-excluded-models:
+# kilo:
+# - "kilo-claude-opus-4-6" # exclude specific models (exact match)
+# - "*:free" # wildcard matching suffix (e.g. all free models)
+
# OpenAI compatibility providers
# openai-compatibility:
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
@@ -229,10 +259,25 @@ nonstream-keepalive-interval: 0
# Global OAuth model name aliases (per channel)
# These aliases rename model IDs for both model listing and request routing.
-# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kimi.
+# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# You can repeat the same name with different aliases to expose multiple client model names.
# oauth-model-alias:
+# antigravity:
+# - name: "rev19-uic3-1p"
+# alias: "gemini-2.5-computer-use-preview-10-2025"
+# - name: "gemini-3-pro-image"
+# alias: "gemini-3-pro-image-preview"
+# - name: "gemini-3-pro-high"
+# alias: "gemini-3-pro-preview"
+# - name: "gemini-3-flash"
+# alias: "gemini-3-flash-preview"
+# - name: "claude-sonnet-4-5"
+# alias: "gemini-claude-sonnet-4-5"
+# - name: "claude-sonnet-4-5-thinking"
+# alias: "gemini-claude-sonnet-4-5-thinking"
+# - name: "claude-opus-4-5-thinking"
+# alias: "gemini-claude-opus-4-5-thinking"
# gemini-cli:
# - name: "gemini-2.5-pro" # original model name under this channel
# alias: "g2.5p" # client-visible alias
@@ -243,9 +288,6 @@ nonstream-keepalive-interval: 0
# aistudio:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
-# antigravity:
-# - name: "gemini-3-pro-high"
-# alias: "gemini-3-pro-preview"
# claude:
# - name: "claude-sonnet-4-5-20250929"
# alias: "cs4.5"
@@ -261,8 +303,15 @@ nonstream-keepalive-interval: 0
# kimi:
# - name: "kimi-k2.5"
# alias: "k2.5"
+# kiro:
+# - name: "kiro-claude-opus-4-5"
+# alias: "op45"
+# github-copilot:
+# - name: "gpt-5"
+# alias: "copilot-gpt5"
# OAuth provider excluded models
+# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
# oauth-excluded-models:
# gemini-cli:
# - "gemini-2.5-pro" # exclude specific models (exact match)
@@ -285,6 +334,10 @@ nonstream-keepalive-interval: 0
# - "tstars2.0"
# kimi:
# - "kimi-k2-thinking"
+# kiro:
+# - "kiro-claude-haiku-4-5"
+# github-copilot:
+# - "raptor-mini"
# Optional payload configuration
# payload:
diff --git a/docker-compose.yml b/docker-compose.yml
index ad2190c2..cd8c21b9 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -1,6 +1,6 @@
services:
cli-proxy-api:
- image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api:latest}
+ image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api-plus:latest}
pull_policy: always
build:
context: .
@@ -9,7 +9,7 @@ services:
VERSION: ${VERSION:-dev}
COMMIT: ${COMMIT:-none}
BUILD_DATE: ${BUILD_DATE:-unknown}
- container_name: cli-proxy-api
+ container_name: cli-proxy-api-plus
# env_file:
# - .env
environment:
diff --git a/go.mod b/go.mod
index 34237de9..461d5517 100644
--- a/go.mod
+++ b/go.mod
@@ -9,6 +9,7 @@ require (
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/fsnotify/fsnotify v1.9.0
+ github.com/fxamacker/cbor/v2 v2.9.0
github.com/gin-gonic/gin v1.10.1
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
github.com/google/uuid v1.6.0
@@ -17,9 +18,9 @@ require (
github.com/joho/godotenv v1.5.1
github.com/klauspost/compress v1.17.4
github.com/minio/minio-go/v7 v7.0.66
+ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/refraction-networking/utls v1.8.2
github.com/sirupsen/logrus v1.9.3
- github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/tiktoken-go/tokenizer v0.7.0
@@ -27,6 +28,7 @@ require (
golang.org/x/net v0.47.0
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.18.0
+ golang.org/x/term v0.37.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
)
@@ -90,6 +92,7 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
+ github.com/x448/float16 v0.8.4 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
diff --git a/go.sum b/go.sum
index 3c424c5e..8a4a967d 100644
--- a/go.sum
+++ b/go.sum
@@ -61,6 +61,8 @@ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
+github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM=
+github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -154,6 +156,8 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
+github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
+github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
@@ -168,8 +172,6 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
-github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
-github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -201,6 +203,8 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
+github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
+github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
@@ -216,6 +220,7 @@ golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
diff --git a/internal/api/handlers/management/api_tools.go b/internal/api/handlers/management/api_tools.go
index c7846a75..666ff248 100644
--- a/internal/api/handlers/management/api_tools.go
+++ b/internal/api/handlers/management/api_tools.go
@@ -1,6 +1,7 @@
package management
import (
+ "bytes"
"context"
"encoding/json"
"fmt"
@@ -11,13 +12,15 @@ import (
"strings"
"time"
+ "github.com/fxamacker/cbor/v2"
"github.com/gin-gonic/gin"
- "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
- coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
const defaultAPICallTimeout = 60 * time.Second
@@ -54,6 +57,7 @@ type apiCallResponse struct {
StatusCode int `json:"status_code"`
Header map[string][]string `json:"header"`
Body string `json:"body"`
+ Quota *QuotaSnapshots `json:"quota,omitempty"`
}
// APICall makes a generic HTTP request on behalf of the management API caller.
@@ -70,7 +74,7 @@ type apiCallResponse struct {
// - Authorization: Bearer
// - X-Management-Key:
//
-// Request JSON:
+// Request JSON (supports both application/json and application/cbor):
// - auth_index / authIndex / AuthIndex (optional):
// The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
// If omitted or not found, credential-specific proxy/token substitution is skipped.
@@ -90,10 +94,14 @@ type apiCallResponse struct {
// 2. Global config proxy-url
// 3. Direct connect (environment proxies are not used)
//
-// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
-// - status_code: Upstream HTTP status code.
-// - header: Upstream response headers.
-// - body: Upstream response body as string.
+// Response (returned with HTTP 200 when the APICall itself succeeds):
+//
+// Format matches request Content-Type (application/json or application/cbor)
+// - status_code: Upstream HTTP status code.
+// - header: Upstream response headers.
+// - body: Upstream response body as string.
+// - quota (optional): For GitHub Copilot enterprise accounts, contains quota_snapshots
+// with details for chat, completions, and premium_interactions.
//
// Example:
//
@@ -107,10 +115,28 @@ type apiCallResponse struct {
// -H "Content-Type: application/json" \
// -d '{"auth_index":"","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) {
+ // Detect content type
+ contentType := strings.ToLower(strings.TrimSpace(c.GetHeader("Content-Type")))
+ isCBOR := strings.Contains(contentType, "application/cbor")
+
var body apiCallRequest
- if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
- return
+
+ // Parse request body based on content type
+ if isCBOR {
+ rawBody, errRead := io.ReadAll(c.Request.Body)
+ if errRead != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
+ return
+ }
+ if errUnmarshal := cbor.Unmarshal(rawBody, &body); errUnmarshal != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid cbor body"})
+ return
+ }
+ } else {
+ if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
+ return
+ }
}
method := strings.ToUpper(strings.TrimSpace(body.Method))
@@ -164,9 +190,21 @@ func (h *Handler) APICall(c *gin.Context) {
reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
}
+ // When caller indicates CBOR in request headers, convert JSON string payload to CBOR bytes.
+ useCBORPayload := headerContainsValue(reqHeaders, "Content-Type", "application/cbor")
+
var requestBody io.Reader
if body.Data != "" {
- requestBody = strings.NewReader(body.Data)
+ if useCBORPayload {
+ cborPayload, errEncode := encodeJSONStringToCBOR(body.Data)
+ if errEncode != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json data for cbor content-type"})
+ return
+ }
+ requestBody = bytes.NewReader(cborPayload)
+ } else {
+ requestBody = strings.NewReader(body.Data)
+ }
}
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
@@ -209,11 +247,38 @@ func (h *Handler) APICall(c *gin.Context) {
return
}
- c.JSON(http.StatusOK, apiCallResponse{
+ // For CBOR upstream responses, decode into plain text or JSON string before returning.
+ responseBodyText := string(respBody)
+ if headerContainsValue(reqHeaders, "Accept", "application/cbor") || strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "application/cbor") {
+ if decodedBody, errDecode := decodeCBORBodyToTextOrJSON(respBody); errDecode == nil {
+ responseBodyText = decodedBody
+ }
+ }
+
+ response := apiCallResponse{
StatusCode: resp.StatusCode,
Header: resp.Header,
- Body: string(respBody),
- })
+ Body: responseBodyText,
+ }
+
+ // If this is a GitHub Copilot token endpoint response, try to enrich with quota information
+ if resp.StatusCode == http.StatusOK &&
+ strings.Contains(urlStr, "copilot_internal") &&
+ strings.Contains(urlStr, "/token") {
+ response = h.enrichCopilotTokenResponse(c.Request.Context(), response, auth, urlStr)
+ }
+
+ // Return response in the same format as the request
+ if isCBOR {
+ cborData, errMarshal := cbor.Marshal(response)
+ if errMarshal != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to encode cbor response"})
+ return
+ }
+ c.Data(http.StatusOK, "application/cbor", cborData)
+ } else {
+ c.JSON(http.StatusOK, response)
+ }
}
func firstNonEmptyString(values ...*string) string {
@@ -702,3 +767,421 @@ func buildProxyTransport(proxyStr string) *http.Transport {
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
}
+
+// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
+func headerContainsValue(headers map[string]string, targetKey, targetValue string) bool {
+ if len(headers) == 0 {
+ return false
+ }
+ for key, value := range headers {
+ if !strings.EqualFold(strings.TrimSpace(key), strings.TrimSpace(targetKey)) {
+ continue
+ }
+ if strings.Contains(strings.ToLower(value), strings.ToLower(strings.TrimSpace(targetValue))) {
+ return true
+ }
+ }
+ return false
+}
+
+// encodeJSONStringToCBOR converts a JSON string payload into CBOR bytes.
+func encodeJSONStringToCBOR(jsonString string) ([]byte, error) {
+ var payload any
+ if errUnmarshal := json.Unmarshal([]byte(jsonString), &payload); errUnmarshal != nil {
+ return nil, errUnmarshal
+ }
+ return cbor.Marshal(payload)
+}
+
+// decodeCBORBodyToTextOrJSON decodes CBOR bytes to plain text (for string payloads) or JSON string.
+func decodeCBORBodyToTextOrJSON(raw []byte) (string, error) {
+ if len(raw) == 0 {
+ return "", nil
+ }
+
+ var payload any
+ if errUnmarshal := cbor.Unmarshal(raw, &payload); errUnmarshal != nil {
+ return "", errUnmarshal
+ }
+
+ jsonCompatible := cborValueToJSONCompatible(payload)
+ switch typed := jsonCompatible.(type) {
+ case string:
+ return typed, nil
+ case []byte:
+ return string(typed), nil
+ default:
+ jsonBytes, errMarshal := json.Marshal(jsonCompatible)
+ if errMarshal != nil {
+ return "", errMarshal
+ }
+ return string(jsonBytes), nil
+ }
+}
+
+// cborValueToJSONCompatible recursively converts CBOR-decoded values into JSON-marshalable values.
+func cborValueToJSONCompatible(value any) any {
+ switch typed := value.(type) {
+ case map[any]any:
+ out := make(map[string]any, len(typed))
+ for key, item := range typed {
+ out[fmt.Sprint(key)] = cborValueToJSONCompatible(item)
+ }
+ return out
+ case map[string]any:
+ out := make(map[string]any, len(typed))
+ for key, item := range typed {
+ out[key] = cborValueToJSONCompatible(item)
+ }
+ return out
+ case []any:
+ out := make([]any, len(typed))
+ for i, item := range typed {
+ out[i] = cborValueToJSONCompatible(item)
+ }
+ return out
+ default:
+ return typed
+ }
+}
+
+// QuotaDetail represents quota information for a specific resource type
+type QuotaDetail struct {
+ Entitlement float64 `json:"entitlement"`
+ OverageCount float64 `json:"overage_count"`
+ OveragePermitted bool `json:"overage_permitted"`
+ PercentRemaining float64 `json:"percent_remaining"`
+ QuotaID string `json:"quota_id"`
+ QuotaRemaining float64 `json:"quota_remaining"`
+ Remaining float64 `json:"remaining"`
+ Unlimited bool `json:"unlimited"`
+}
+
+// QuotaSnapshots contains quota details for different resource types
+type QuotaSnapshots struct {
+ Chat QuotaDetail `json:"chat"`
+ Completions QuotaDetail `json:"completions"`
+ PremiumInteractions QuotaDetail `json:"premium_interactions"`
+}
+
+// CopilotUsageResponse represents the GitHub Copilot usage information
+type CopilotUsageResponse struct {
+ AccessTypeSKU string `json:"access_type_sku"`
+ AnalyticsTrackingID string `json:"analytics_tracking_id"`
+ AssignedDate string `json:"assigned_date"`
+ CanSignupForLimited bool `json:"can_signup_for_limited"`
+ ChatEnabled bool `json:"chat_enabled"`
+ CopilotPlan string `json:"copilot_plan"`
+ OrganizationLoginList []interface{} `json:"organization_login_list"`
+ OrganizationList []interface{} `json:"organization_list"`
+ QuotaResetDate string `json:"quota_reset_date"`
+ QuotaSnapshots QuotaSnapshots `json:"quota_snapshots"`
+}
+
+type copilotQuotaRequest struct {
+ AuthIndexSnake *string `json:"auth_index"`
+ AuthIndexCamel *string `json:"authIndex"`
+ AuthIndexPascal *string `json:"AuthIndex"`
+}
+
+// GetCopilotQuota fetches GitHub Copilot quota information from the /copilot_internal/user endpoint.
+//
+// Endpoint:
+//
+// GET /v0/management/copilot-quota
+//
+// Query Parameters (optional):
+// - auth_index: The credential "auth_index" from GET /v0/management/auth-files.
+// If omitted, uses the first available GitHub Copilot credential.
+//
+// Response:
+//
+// Returns the CopilotUsageResponse with quota_snapshots containing detailed quota information
+// for chat, completions, and premium_interactions.
+//
+// Example:
+//
+// curl -sS -X GET "http://127.0.0.1:8317/v0/management/copilot-quota?auth_index=" \
+// -H "Authorization: Bearer "
+func (h *Handler) GetCopilotQuota(c *gin.Context) {
+ authIndex := strings.TrimSpace(c.Query("auth_index"))
+ if authIndex == "" {
+ authIndex = strings.TrimSpace(c.Query("authIndex"))
+ }
+ if authIndex == "" {
+ authIndex = strings.TrimSpace(c.Query("AuthIndex"))
+ }
+
+ auth := h.findCopilotAuth(authIndex)
+ if auth == nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "no github copilot credential found"})
+ return
+ }
+
+ token, tokenErr := h.resolveTokenForAuth(c.Request.Context(), auth)
+ if tokenErr != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "failed to refresh copilot token"})
+ return
+ }
+ if token == "" {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "copilot token not found"})
+ return
+ }
+
+ apiURL := "https://api.github.com/copilot_internal/user"
+ req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, apiURL, nil)
+ if errNewRequest != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to build request"})
+ return
+ }
+
+ req.Header.Set("Authorization", "Bearer "+token)
+ req.Header.Set("User-Agent", "CLIProxyAPIPlus")
+ req.Header.Set("Accept", "application/json")
+
+ httpClient := &http.Client{
+ Timeout: defaultAPICallTimeout,
+ Transport: h.apiCallTransport(auth),
+ }
+
+ resp, errDo := httpClient.Do(req)
+ if errDo != nil {
+ log.WithError(errDo).Debug("copilot quota request failed")
+ c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
+ return
+ }
+ defer func() {
+ if errClose := resp.Body.Close(); errClose != nil {
+ log.Errorf("response body close error: %v", errClose)
+ }
+ }()
+
+ respBody, errReadAll := io.ReadAll(resp.Body)
+ if errReadAll != nil {
+ c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
+ return
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ c.JSON(http.StatusBadGateway, gin.H{
+ "error": "github api request failed",
+ "status_code": resp.StatusCode,
+ "body": string(respBody),
+ })
+ return
+ }
+
+ var usage CopilotUsageResponse
+ if errUnmarshal := json.Unmarshal(respBody, &usage); errUnmarshal != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to parse response"})
+ return
+ }
+
+ c.JSON(http.StatusOK, usage)
+}
+
+// findCopilotAuth locates a GitHub Copilot credential by auth_index or returns the first available one
+func (h *Handler) findCopilotAuth(authIndex string) *coreauth.Auth {
+ if h == nil || h.authManager == nil {
+ return nil
+ }
+
+ auths := h.authManager.List()
+ var firstCopilot *coreauth.Auth
+
+ for _, auth := range auths {
+ if auth == nil {
+ continue
+ }
+
+ provider := strings.ToLower(strings.TrimSpace(auth.Provider))
+ if provider != "copilot" && provider != "github" && provider != "github-copilot" {
+ continue
+ }
+
+ if firstCopilot == nil {
+ firstCopilot = auth
+ }
+
+ if authIndex != "" {
+ auth.EnsureIndex()
+ if auth.Index == authIndex {
+ return auth
+ }
+ }
+ }
+
+ return firstCopilot
+}
+
+// enrichCopilotTokenResponse fetches quota information and adds it to the Copilot token response body
+func (h *Handler) enrichCopilotTokenResponse(ctx context.Context, response apiCallResponse, auth *coreauth.Auth, originalURL string) apiCallResponse {
+ if auth == nil || response.Body == "" {
+ return response
+ }
+
+ // Parse the token response to check if it's enterprise (null limited_user_quotas)
+ var tokenResp map[string]interface{}
+ if err := json.Unmarshal([]byte(response.Body), &tokenResp); err != nil {
+ log.WithError(err).Debug("enrichCopilotTokenResponse: failed to parse copilot token response")
+ return response
+ }
+
+ // Get the GitHub token to call the copilot_internal/user endpoint
+ token, tokenErr := h.resolveTokenForAuth(ctx, auth)
+ if tokenErr != nil {
+ log.WithError(tokenErr).Debug("enrichCopilotTokenResponse: failed to resolve token")
+ return response
+ }
+ if token == "" {
+ return response
+ }
+
+ // Fetch quota information from /copilot_internal/user
+ // Derive the base URL from the original token request to support proxies and test servers
+ parsedURL, errParse := url.Parse(originalURL)
+ if errParse != nil {
+ log.WithError(errParse).Debug("enrichCopilotTokenResponse: failed to parse URL")
+ return response
+ }
+ quotaURL := fmt.Sprintf("%s://%s/copilot_internal/user", parsedURL.Scheme, parsedURL.Host)
+
+ req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodGet, quotaURL, nil)
+ if errNewRequest != nil {
+ log.WithError(errNewRequest).Debug("enrichCopilotTokenResponse: failed to build request")
+ return response
+ }
+
+ req.Header.Set("Authorization", "Bearer "+token)
+ req.Header.Set("User-Agent", "CLIProxyAPIPlus")
+ req.Header.Set("Accept", "application/json")
+
+ httpClient := &http.Client{
+ Timeout: defaultAPICallTimeout,
+ Transport: h.apiCallTransport(auth),
+ }
+
+ quotaResp, errDo := httpClient.Do(req)
+ if errDo != nil {
+ log.WithError(errDo).Debug("enrichCopilotTokenResponse: quota fetch HTTP request failed")
+ return response
+ }
+
+ defer func() {
+ if errClose := quotaResp.Body.Close(); errClose != nil {
+ log.Errorf("quota response body close error: %v", errClose)
+ }
+ }()
+
+ if quotaResp.StatusCode != http.StatusOK {
+ return response
+ }
+
+ quotaBody, errReadAll := io.ReadAll(quotaResp.Body)
+ if errReadAll != nil {
+ log.WithError(errReadAll).Debug("enrichCopilotTokenResponse: failed to read response")
+ return response
+ }
+
+ // Parse the quota response
+ var quotaData CopilotUsageResponse
+ if err := json.Unmarshal(quotaBody, "aData); err != nil {
+ log.WithError(err).Debug("enrichCopilotTokenResponse: failed to parse response")
+ return response
+ }
+
+ // Check if this is an enterprise account by looking for quota_snapshots in the response
+ // Enterprise accounts have quota_snapshots, non-enterprise have limited_user_quotas
+ var quotaRaw map[string]interface{}
+ if err := json.Unmarshal(quotaBody, "aRaw); err == nil {
+ if _, hasQuotaSnapshots := quotaRaw["quota_snapshots"]; hasQuotaSnapshots {
+ // Enterprise account - has quota_snapshots
+ tokenResp["quota_snapshots"] = quotaData.QuotaSnapshots
+ tokenResp["access_type_sku"] = quotaData.AccessTypeSKU
+ tokenResp["copilot_plan"] = quotaData.CopilotPlan
+
+ // Add quota reset date for enterprise (quota_reset_date_utc)
+ if quotaResetDateUTC, ok := quotaRaw["quota_reset_date_utc"]; ok {
+ tokenResp["quota_reset_date"] = quotaResetDateUTC
+ } else if quotaData.QuotaResetDate != "" {
+ tokenResp["quota_reset_date"] = quotaData.QuotaResetDate
+ }
+ } else {
+ // Non-enterprise account - build quota from limited_user_quotas and monthly_quotas
+ var quotaSnapshots QuotaSnapshots
+
+ // Get monthly quotas (total entitlement) and limited_user_quotas (remaining)
+ monthlyQuotas, hasMonthly := quotaRaw["monthly_quotas"].(map[string]interface{})
+ limitedQuotas, hasLimited := quotaRaw["limited_user_quotas"].(map[string]interface{})
+
+ // Process chat quota
+ if hasMonthly && hasLimited {
+ if chatTotal, ok := monthlyQuotas["chat"].(float64); ok {
+ chatRemaining := chatTotal // default to full if no limited quota
+ if chatLimited, ok := limitedQuotas["chat"].(float64); ok {
+ chatRemaining = chatLimited
+ }
+ percentRemaining := 0.0
+ if chatTotal > 0 {
+ percentRemaining = (chatRemaining / chatTotal) * 100.0
+ }
+ quotaSnapshots.Chat = QuotaDetail{
+ Entitlement: chatTotal,
+ Remaining: chatRemaining,
+ QuotaRemaining: chatRemaining,
+ PercentRemaining: percentRemaining,
+ QuotaID: "chat",
+ Unlimited: false,
+ }
+ }
+
+ // Process completions quota
+ if completionsTotal, ok := monthlyQuotas["completions"].(float64); ok {
+ completionsRemaining := completionsTotal // default to full if no limited quota
+ if completionsLimited, ok := limitedQuotas["completions"].(float64); ok {
+ completionsRemaining = completionsLimited
+ }
+ percentRemaining := 0.0
+ if completionsTotal > 0 {
+ percentRemaining = (completionsRemaining / completionsTotal) * 100.0
+ }
+ quotaSnapshots.Completions = QuotaDetail{
+ Entitlement: completionsTotal,
+ Remaining: completionsRemaining,
+ QuotaRemaining: completionsRemaining,
+ PercentRemaining: percentRemaining,
+ QuotaID: "completions",
+ Unlimited: false,
+ }
+ }
+ }
+
+ // Premium interactions don't exist for non-enterprise, leave as zero values
+ quotaSnapshots.PremiumInteractions = QuotaDetail{
+ QuotaID: "premium_interactions",
+ Unlimited: false,
+ }
+
+ // Add quota_snapshots to the token response
+ tokenResp["quota_snapshots"] = quotaSnapshots
+ tokenResp["access_type_sku"] = quotaData.AccessTypeSKU
+ tokenResp["copilot_plan"] = quotaData.CopilotPlan
+
+ // Add quota reset date for non-enterprise (limited_user_reset_date)
+ if limitedResetDate, ok := quotaRaw["limited_user_reset_date"]; ok {
+ tokenResp["quota_reset_date"] = limitedResetDate
+ }
+ }
+ }
+
+ // Re-serialize the enriched response
+ enrichedBody, errMarshal := json.Marshal(tokenResp)
+ if errMarshal != nil {
+ log.WithError(errMarshal).Debug("failed to marshal enriched response")
+ return response
+ }
+
+ response.Body = string(enrichedBody)
+
+ return response
+}
diff --git a/internal/api/handlers/management/api_tools_cbor_test.go b/internal/api/handlers/management/api_tools_cbor_test.go
new file mode 100644
index 00000000..8b7570a9
--- /dev/null
+++ b/internal/api/handlers/management/api_tools_cbor_test.go
@@ -0,0 +1,149 @@
+package management
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/fxamacker/cbor/v2"
+ "github.com/gin-gonic/gin"
+)
+
+func TestAPICall_CBOR_Support(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ // Create a test handler
+ h := &Handler{}
+
+ // Create test request data
+ reqData := apiCallRequest{
+ Method: "GET",
+ URL: "https://httpbin.org/get",
+ Header: map[string]string{
+ "User-Agent": "test-client",
+ },
+ }
+
+ t.Run("JSON request and response", func(t *testing.T) {
+ // Marshal request as JSON
+ jsonData, err := json.Marshal(reqData)
+ if err != nil {
+ t.Fatalf("Failed to marshal JSON: %v", err)
+ }
+
+ // Create HTTP request
+ req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(jsonData))
+ req.Header.Set("Content-Type", "application/json")
+
+ // Create response recorder
+ w := httptest.NewRecorder()
+
+ // Create Gin context
+ c, _ := gin.CreateTestContext(w)
+ c.Request = req
+
+ // Call handler
+ h.APICall(c)
+
+ // Verify response
+ if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
+ t.Logf("Response status: %d", w.Code)
+ t.Logf("Response body: %s", w.Body.String())
+ }
+
+ // Check content type
+ contentType := w.Header().Get("Content-Type")
+ if w.Code == http.StatusOK && !contains(contentType, "application/json") {
+ t.Errorf("Expected JSON response, got: %s", contentType)
+ }
+ })
+
+ t.Run("CBOR request and response", func(t *testing.T) {
+ // Marshal request as CBOR
+ cborData, err := cbor.Marshal(reqData)
+ if err != nil {
+ t.Fatalf("Failed to marshal CBOR: %v", err)
+ }
+
+ // Create HTTP request
+ req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(cborData))
+ req.Header.Set("Content-Type", "application/cbor")
+
+ // Create response recorder
+ w := httptest.NewRecorder()
+
+ // Create Gin context
+ c, _ := gin.CreateTestContext(w)
+ c.Request = req
+
+ // Call handler
+ h.APICall(c)
+
+ // Verify response
+ if w.Code != http.StatusOK && w.Code != http.StatusBadGateway {
+ t.Logf("Response status: %d", w.Code)
+ t.Logf("Response body: %s", w.Body.String())
+ }
+
+ // Check content type
+ contentType := w.Header().Get("Content-Type")
+ if w.Code == http.StatusOK && !contains(contentType, "application/cbor") {
+ t.Errorf("Expected CBOR response, got: %s", contentType)
+ }
+
+ // Try to decode CBOR response
+ if w.Code == http.StatusOK {
+ var response apiCallResponse
+ if err := cbor.Unmarshal(w.Body.Bytes(), &response); err != nil {
+ t.Errorf("Failed to unmarshal CBOR response: %v", err)
+ } else {
+ t.Logf("CBOR response decoded successfully: status_code=%d", response.StatusCode)
+ }
+ }
+ })
+
+ t.Run("CBOR encoding and decoding consistency", func(t *testing.T) {
+ // Test data
+ testReq := apiCallRequest{
+ Method: "POST",
+ URL: "https://example.com/api",
+ Header: map[string]string{
+ "Authorization": "Bearer $TOKEN$",
+ "Content-Type": "application/json",
+ },
+ Data: `{"key":"value"}`,
+ }
+
+ // Encode to CBOR
+ cborData, err := cbor.Marshal(testReq)
+ if err != nil {
+ t.Fatalf("Failed to marshal to CBOR: %v", err)
+ }
+
+ // Decode from CBOR
+ var decoded apiCallRequest
+ if err := cbor.Unmarshal(cborData, &decoded); err != nil {
+ t.Fatalf("Failed to unmarshal from CBOR: %v", err)
+ }
+
+ // Verify fields
+ if decoded.Method != testReq.Method {
+ t.Errorf("Method mismatch: got %s, want %s", decoded.Method, testReq.Method)
+ }
+ if decoded.URL != testReq.URL {
+ t.Errorf("URL mismatch: got %s, want %s", decoded.URL, testReq.URL)
+ }
+ if decoded.Data != testReq.Data {
+ t.Errorf("Data mismatch: got %s, want %s", decoded.Data, testReq.Data)
+ }
+ if len(decoded.Header) != len(testReq.Header) {
+ t.Errorf("Header count mismatch: got %d, want %d", len(decoded.Header), len(testReq.Header))
+ }
+ })
+}
+
+func contains(s, substr string) bool {
+ return len(s) > 0 && len(substr) > 0 && (s == substr || len(s) >= len(substr) && s[:len(substr)] == substr || bytes.Contains([]byte(s), []byte(substr)))
+}
diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go
index 7f7fad15..bd1338a2 100644
--- a/internal/api/handlers/management/auth_files.go
+++ b/internal/api/handlers/management/auth_files.go
@@ -3,7 +3,9 @@ package management
import (
"bytes"
"context"
+ "crypto/rand"
"crypto/sha256"
+ "encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
@@ -11,6 +13,7 @@ import (
"io"
"net"
"net/http"
+ "net/url"
"os"
"path/filepath"
"sort"
@@ -23,9 +26,12 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
+ kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
@@ -1903,6 +1909,89 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
}
+func (h *Handler) RequestGitHubToken(c *gin.Context) {
+ ctx := context.Background()
+
+ fmt.Println("Initializing GitHub Copilot authentication...")
+
+ state := fmt.Sprintf("gh-%d", time.Now().UnixNano())
+
+ // Initialize Copilot auth service
+ // We need to import "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" first if not present
+ // Assuming copilot package is imported as "copilot"
+ deviceClient := copilot.NewDeviceFlowClient(h.cfg)
+
+ // Initiate device flow
+ deviceCode, err := deviceClient.RequestDeviceCode(ctx)
+ if err != nil {
+ log.Errorf("Failed to initiate device flow: %v", err)
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
+ return
+ }
+
+ authURL := deviceCode.VerificationURI
+ userCode := deviceCode.UserCode
+
+ RegisterOAuthSession(state, "github")
+
+ go func() {
+ fmt.Printf("Please visit %s and enter code: %s\n", authURL, userCode)
+
+ tokenData, errPoll := deviceClient.PollForToken(ctx, deviceCode)
+ if errPoll != nil {
+ SetOAuthSessionError(state, "Authentication failed")
+ fmt.Printf("Authentication failed: %v\n", errPoll)
+ return
+ }
+
+ username, errUser := deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
+ if errUser != nil {
+ log.Warnf("Failed to fetch user info: %v", errUser)
+ username = "github-user"
+ }
+
+ tokenStorage := &copilot.CopilotTokenStorage{
+ AccessToken: tokenData.AccessToken,
+ TokenType: tokenData.TokenType,
+ Scope: tokenData.Scope,
+ Username: username,
+ Type: "github-copilot",
+ }
+
+ fileName := fmt.Sprintf("github-%s.json", username)
+ record := &coreauth.Auth{
+ ID: fileName,
+ Provider: "github",
+ FileName: fileName,
+ Storage: tokenStorage,
+ Metadata: map[string]any{
+ "email": username,
+ "username": username,
+ },
+ }
+
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
+ if errSave != nil {
+ log.Errorf("Failed to save authentication tokens: %v", errSave)
+ SetOAuthSessionError(state, "Failed to save authentication tokens")
+ return
+ }
+
+ fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
+ fmt.Println("You can now use GitHub Copilot services through this CLI")
+ CompleteOAuthSession(state)
+ CompleteOAuthSessionsByProvider("github")
+ }()
+
+ c.JSON(200, gin.H{
+ "status": "ok",
+ "url": authURL,
+ "state": state,
+ "user_code": userCode,
+ "verification_uri": authURL,
+ })
+}
+
func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
ctx := context.Background()
@@ -2407,8 +2496,407 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
return
}
if status != "" {
+ if strings.HasPrefix(status, "device_code|") {
+ parts := strings.SplitN(status, "|", 3)
+ if len(parts) == 3 {
+ c.JSON(http.StatusOK, gin.H{
+ "status": "device_code",
+ "verification_url": parts[1],
+ "user_code": parts[2],
+ })
+ return
+ }
+ }
+ if strings.HasPrefix(status, "auth_url|") {
+ authURL := strings.TrimPrefix(status, "auth_url|")
+ c.JSON(http.StatusOK, gin.H{
+ "status": "auth_url",
+ "url": authURL,
+ })
+ return
+ }
c.JSON(http.StatusOK, gin.H{"status": "error", "error": status})
return
}
c.JSON(http.StatusOK, gin.H{"status": "wait"})
}
+
+const kiroCallbackPort = 9876
+
+func (h *Handler) RequestKiroToken(c *gin.Context) {
+ ctx := context.Background()
+
+ // Get the login method from query parameter (default: aws for device code flow)
+ method := strings.ToLower(strings.TrimSpace(c.Query("method")))
+ if method == "" {
+ method = "aws"
+ }
+
+ fmt.Println("Initializing Kiro authentication...")
+
+ state := fmt.Sprintf("kiro-%d", time.Now().UnixNano())
+
+ switch method {
+ case "aws", "builder-id":
+ RegisterOAuthSession(state, "kiro")
+
+ // AWS Builder ID uses device code flow (no callback needed)
+ go func() {
+ ssoClient := kiroauth.NewSSOOIDCClient(h.cfg)
+
+ // Step 1: Register client
+ fmt.Println("Registering client...")
+ regResp, errRegister := ssoClient.RegisterClient(ctx)
+ if errRegister != nil {
+ log.Errorf("Failed to register client: %v", errRegister)
+ SetOAuthSessionError(state, "Failed to register client")
+ return
+ }
+
+ // Step 2: Start device authorization
+ fmt.Println("Starting device authorization...")
+ authResp, errAuth := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
+ if errAuth != nil {
+ log.Errorf("Failed to start device auth: %v", errAuth)
+ SetOAuthSessionError(state, "Failed to start device authorization")
+ return
+ }
+
+ // Store the verification URL for the frontend to display.
+ // Using "|" as separator because URLs contain ":".
+ SetOAuthSessionError(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode)
+
+ // Step 3: Poll for token
+ fmt.Println("Waiting for authorization...")
+ interval := 5 * time.Second
+ if authResp.Interval > 0 {
+ interval = time.Duration(authResp.Interval) * time.Second
+ }
+ deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
+
+ for time.Now().Before(deadline) {
+ select {
+ case <-ctx.Done():
+ SetOAuthSessionError(state, "Authorization cancelled")
+ return
+ case <-time.After(interval):
+ tokenResp, errToken := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
+ if errToken != nil {
+ errStr := errToken.Error()
+ if strings.Contains(errStr, "authorization_pending") {
+ continue
+ }
+ if strings.Contains(errStr, "slow_down") {
+ interval += 5 * time.Second
+ continue
+ }
+ log.Errorf("Token creation failed: %v", errToken)
+ SetOAuthSessionError(state, "Token creation failed")
+ return
+ }
+
+ // Success! Save the token
+ expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
+ email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
+
+ idPart := kiroauth.SanitizeEmailForFilename(email)
+ if idPart == "" {
+ idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
+ }
+
+ now := time.Now()
+ fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
+
+ record := &coreauth.Auth{
+ ID: fileName,
+ Provider: "kiro",
+ FileName: fileName,
+ Metadata: map[string]any{
+ "type": "kiro",
+ "access_token": tokenResp.AccessToken,
+ "refresh_token": tokenResp.RefreshToken,
+ "expires_at": expiresAt.Format(time.RFC3339),
+ "auth_method": "builder-id",
+ "provider": "AWS",
+ "client_id": regResp.ClientID,
+ "client_secret": regResp.ClientSecret,
+ "email": email,
+ "last_refresh": now.Format(time.RFC3339),
+ },
+ }
+
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
+ if errSave != nil {
+ log.Errorf("Failed to save authentication tokens: %v", errSave)
+ SetOAuthSessionError(state, "Failed to save authentication tokens")
+ return
+ }
+
+ fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
+ if email != "" {
+ fmt.Printf("Authenticated as: %s\n", email)
+ }
+ CompleteOAuthSession(state)
+ return
+ }
+ }
+
+ SetOAuthSessionError(state, "Authorization timed out")
+ }()
+
+ // Return immediately with the state for polling
+ c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "device_code"})
+
+ case "google", "github":
+ RegisterOAuthSession(state, "kiro")
+
+ // Social auth uses protocol handler - for WEB UI we use a callback forwarder
+ provider := "Google"
+ if method == "github" {
+ provider = "Github"
+ }
+
+ isWebUI := isWebUIRequest(c)
+ if isWebUI {
+ targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
+ if errTarget != nil {
+ log.WithError(errTarget).Error("failed to compute kiro callback target")
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
+ return
+ }
+ if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
+ log.WithError(errStart).Error("failed to start kiro callback forwarder")
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
+ return
+ }
+ }
+
+ go func() {
+ if isWebUI {
+ defer stopCallbackForwarder(kiroCallbackPort)
+ }
+
+ socialClient := kiroauth.NewSocialAuthClient(h.cfg)
+
+ // Generate PKCE codes
+ codeVerifier, codeChallenge, errPKCE := generateKiroPKCE()
+ if errPKCE != nil {
+ log.Errorf("Failed to generate PKCE: %v", errPKCE)
+ SetOAuthSessionError(state, "Failed to generate PKCE")
+ return
+ }
+
+ // Build login URL
+ authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
+ "https://prod.us-east-1.auth.desktop.kiro.dev",
+ provider,
+ url.QueryEscape(kiroauth.KiroRedirectURI),
+ codeChallenge,
+ state,
+ )
+
+ // Store auth URL for frontend.
+ // Using "|" as separator because URLs contain ":".
+ SetOAuthSessionError(state, "auth_url|"+authURL)
+
+ // Wait for callback file
+ waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state))
+ deadline := time.Now().Add(5 * time.Minute)
+
+ for {
+ if time.Now().After(deadline) {
+ log.Error("oauth flow timed out")
+ SetOAuthSessionError(state, "OAuth flow timed out")
+ return
+ }
+ if data, errRead := os.ReadFile(waitFile); errRead == nil {
+ var m map[string]string
+ _ = json.Unmarshal(data, &m)
+ _ = os.Remove(waitFile)
+ if errStr := m["error"]; errStr != "" {
+ log.Errorf("Authentication failed: %s", errStr)
+ SetOAuthSessionError(state, "Authentication failed")
+ return
+ }
+ if m["state"] != state {
+ log.Errorf("State mismatch")
+ SetOAuthSessionError(state, "State mismatch")
+ return
+ }
+ code := m["code"]
+ if code == "" {
+ log.Error("No authorization code received")
+ SetOAuthSessionError(state, "No authorization code received")
+ return
+ }
+
+ // Exchange code for tokens
+ tokenReq := &kiroauth.CreateTokenRequest{
+ Code: code,
+ CodeVerifier: codeVerifier,
+ RedirectURI: kiroauth.KiroRedirectURI,
+ }
+
+ tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq)
+ if errToken != nil {
+ log.Errorf("Failed to exchange code for tokens: %v", errToken)
+ SetOAuthSessionError(state, "Failed to exchange code for tokens")
+ return
+ }
+
+ // Save the token
+ expiresIn := tokenResp.ExpiresIn
+ if expiresIn <= 0 {
+ expiresIn = 3600
+ }
+ expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
+ email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
+
+ idPart := kiroauth.SanitizeEmailForFilename(email)
+ if idPart == "" {
+ idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
+ }
+
+ now := time.Now()
+ fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart)
+
+ record := &coreauth.Auth{
+ ID: fileName,
+ Provider: "kiro",
+ FileName: fileName,
+ Metadata: map[string]any{
+ "type": "kiro",
+ "access_token": tokenResp.AccessToken,
+ "refresh_token": tokenResp.RefreshToken,
+ "profile_arn": tokenResp.ProfileArn,
+ "expires_at": expiresAt.Format(time.RFC3339),
+ "auth_method": "social",
+ "provider": provider,
+ "email": email,
+ "last_refresh": now.Format(time.RFC3339),
+ },
+ }
+
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
+ if errSave != nil {
+ log.Errorf("Failed to save authentication tokens: %v", errSave)
+ SetOAuthSessionError(state, "Failed to save authentication tokens")
+ return
+ }
+
+ fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
+ if email != "" {
+ fmt.Printf("Authenticated as: %s\n", email)
+ }
+ CompleteOAuthSession(state)
+ return
+ }
+ time.Sleep(500 * time.Millisecond)
+ }
+ }()
+
+ c.JSON(http.StatusOK, gin.H{"status": "ok", "state": state, "method": "social"})
+
+ default:
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"})
+ }
+}
+
+// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth.
+func generateKiroPKCE() (verifier, challenge string, err error) {
+ b := make([]byte, 32)
+ if _, errRead := io.ReadFull(rand.Reader, b); errRead != nil {
+ return "", "", fmt.Errorf("failed to generate random bytes: %w", errRead)
+ }
+ verifier = base64.RawURLEncoding.EncodeToString(b)
+
+ h := sha256.Sum256([]byte(verifier))
+ challenge = base64.RawURLEncoding.EncodeToString(h[:])
+
+ return verifier, challenge, nil
+}
+
+func (h *Handler) RequestKiloToken(c *gin.Context) {
+ ctx := context.Background()
+
+ fmt.Println("Initializing Kilo authentication...")
+
+ state := fmt.Sprintf("kil-%d", time.Now().UnixNano())
+ kilocodeAuth := kilo.NewKiloAuth()
+
+ resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
+ if err != nil {
+ log.Errorf("Failed to initiate device flow: %v", err)
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initiate device flow"})
+ return
+ }
+
+ RegisterOAuthSession(state, "kilo")
+
+ go func() {
+ fmt.Printf("Please visit %s and enter code: %s\n", resp.VerificationURL, resp.Code)
+
+ status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
+ if err != nil {
+ SetOAuthSessionError(state, "Authentication failed")
+ fmt.Printf("Authentication failed: %v\n", err)
+ return
+ }
+
+ profile, err := kilocodeAuth.GetProfile(ctx, status.Token)
+ if err != nil {
+ log.Warnf("Failed to fetch profile: %v", err)
+ profile = &kilo.Profile{Email: status.UserEmail}
+ }
+
+ var orgID string
+ if len(profile.Orgs) > 0 {
+ orgID = profile.Orgs[0].ID
+ }
+
+ defaults, err := kilocodeAuth.GetDefaults(ctx, status.Token, orgID)
+ if err != nil {
+ defaults = &kilo.Defaults{}
+ }
+
+ ts := &kilo.KiloTokenStorage{
+ Token: status.Token,
+ OrganizationID: orgID,
+ Model: defaults.Model,
+ Email: status.UserEmail,
+ Type: "kilo",
+ }
+
+ fileName := kilo.CredentialFileName(status.UserEmail)
+ record := &coreauth.Auth{
+ ID: fileName,
+ Provider: "kilo",
+ FileName: fileName,
+ Storage: ts,
+ Metadata: map[string]any{
+ "email": status.UserEmail,
+ "organization_id": orgID,
+ "model": defaults.Model,
+ },
+ }
+
+ savedPath, errSave := h.saveTokenRecord(ctx, record)
+ if errSave != nil {
+ log.Errorf("Failed to save authentication tokens: %v", errSave)
+ SetOAuthSessionError(state, "Failed to save authentication tokens")
+ return
+ }
+
+ fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
+ CompleteOAuthSession(state)
+ CompleteOAuthSessionsByProvider("kilo")
+ }()
+
+ c.JSON(200, gin.H{
+ "status": "ok",
+ "url": resp.VerificationURL,
+ "state": state,
+ "user_code": resp.Code,
+ "verification_uri": resp.VerificationURL,
+ })
+}
diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go
index f77e91e9..72f73d32 100644
--- a/internal/api/handlers/management/config_basic.go
+++ b/internal/api/handlers/management/config_basic.go
@@ -19,8 +19,8 @@ import (
)
const (
- latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPI/releases/latest"
- latestReleaseUserAgent = "CLIProxyAPI"
+ latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest"
+ latestReleaseUserAgent = "CLIProxyAPIPlus"
)
func (h *Handler) GetConfig(c *gin.Context) {
diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go
index 66e89992..0153a381 100644
--- a/internal/api/handlers/management/config_lists.go
+++ b/internal/api/handlers/management/config_lists.go
@@ -753,18 +753,22 @@ func (h *Handler) PatchOAuthModelAlias(c *gin.Context) {
normalizedMap := sanitizedOAuthModelAlias(map[string][]config.OAuthModelAlias{channel: body.Aliases})
normalized := normalizedMap[channel]
if len(normalized) == 0 {
+ // Only delete if channel exists, otherwise just create empty entry
+ if h.cfg.OAuthModelAlias != nil {
+ if _, ok := h.cfg.OAuthModelAlias[channel]; ok {
+ delete(h.cfg.OAuthModelAlias, channel)
+ if len(h.cfg.OAuthModelAlias) == 0 {
+ h.cfg.OAuthModelAlias = nil
+ }
+ h.persist(c)
+ return
+ }
+ }
+ // Create new channel with empty aliases
if h.cfg.OAuthModelAlias == nil {
- c.JSON(404, gin.H{"error": "channel not found"})
- return
- }
- if _, ok := h.cfg.OAuthModelAlias[channel]; !ok {
- c.JSON(404, gin.H{"error": "channel not found"})
- return
- }
- delete(h.cfg.OAuthModelAlias, channel)
- if len(h.cfg.OAuthModelAlias) == 0 {
- h.cfg.OAuthModelAlias = nil
+ h.cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias)
}
+ h.cfg.OAuthModelAlias[channel] = []config.OAuthModelAlias{}
h.persist(c)
return
}
@@ -792,10 +796,10 @@ func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
c.JSON(404, gin.H{"error": "channel not found"})
return
}
- delete(h.cfg.OAuthModelAlias, channel)
- if len(h.cfg.OAuthModelAlias) == 0 {
- h.cfg.OAuthModelAlias = nil
- }
+ // Set to nil instead of deleting the key so that the "explicitly disabled"
+ // marker survives config reload and prevents SanitizeOAuthModelAlias from
+ // re-injecting default aliases (fixes #222).
+ h.cfg.OAuthModelAlias[channel] = nil
h.persist(c)
}
diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go
index 05ff8d1f..bc882e99 100644
--- a/internal/api/handlers/management/oauth_sessions.go
+++ b/internal/api/handlers/management/oauth_sessions.go
@@ -158,7 +158,12 @@ func (s *oauthSessionStore) IsPending(state, provider string) bool {
return false
}
if session.Status != "" {
- return false
+ if !strings.EqualFold(session.Provider, "kiro") {
+ return false
+ }
+ if !strings.HasPrefix(session.Status, "device_code|") && !strings.HasPrefix(session.Status, "auth_url|") {
+ return false
+ }
}
if provider == "" {
return true
@@ -231,6 +236,10 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "antigravity", nil
case "qwen":
return "qwen", nil
+ case "kiro":
+ return "kiro", nil
+ case "github":
+ return "github", nil
default:
return "", errUnsupportedOAuthFlow
}
diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go
index c460a0d6..211f0f5d 100644
--- a/internal/api/modules/amp/proxy.go
+++ b/internal/api/modules/amp/proxy.go
@@ -3,8 +3,11 @@ package amp
import (
"bytes"
"compress/gzip"
+ "context"
+ "errors"
"fmt"
"io"
+ "net"
"net/http"
"net/http/httputil"
"net/url"
@@ -102,7 +105,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
// Modify incoming responses to handle gzip without Content-Encoding
// This addresses the same issue as inline handler gzip handling, but at the proxy level
proxy.ModifyResponse = func(resp *http.Response) error {
- // Only process successful responses
+ // Log upstream error responses for diagnostics (502, 503, etc.)
+ // These are NOT proxy connection errors - the upstream responded with an error status
+ if resp.StatusCode >= 500 {
+ log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
+ } else if resp.StatusCode >= 400 {
+ log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
+ }
+
+ // Only process successful responses for gzip decompression
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil
}
@@ -186,9 +197,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
return nil
}
- // Error handler for proxy failures
+ // Error handler for proxy failures with detailed error classification for diagnostics
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
- log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
+ // Classify the error type for better diagnostics
+ var errType string
+ if errors.Is(err, context.DeadlineExceeded) {
+ errType = "timeout"
+ } else if errors.Is(err, context.Canceled) {
+ errType = "canceled"
+ } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+ errType = "dial_timeout"
+ } else if _, ok := err.(net.Error); ok {
+ errType = "network_error"
+ } else {
+ errType = "connection_error"
+ }
+
+ // Don't log as error for context canceled - it's usually client closing connection
+ if errors.Is(err, context.Canceled) {
+ log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
+ } else {
+ log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
+ }
+
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusBadGateway)
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go
index 715034f1..8a9cad70 100644
--- a/internal/api/modules/amp/response_rewriter.go
+++ b/internal/api/modules/amp/response_rewriter.go
@@ -29,15 +29,71 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe
}
}
+const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
+
+func looksLikeSSEChunk(data []byte) bool {
+ // Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
+ // Heuristics are intentionally simple and cheap.
+ return bytes.Contains(data, []byte("data:")) ||
+ bytes.Contains(data, []byte("event:")) ||
+ bytes.Contains(data, []byte("message_start")) ||
+ bytes.Contains(data, []byte("message_delta")) ||
+ bytes.Contains(data, []byte("content_block_start")) ||
+ bytes.Contains(data, []byte("content_block_delta")) ||
+ bytes.Contains(data, []byte("content_block_stop")) ||
+ bytes.Contains(data, []byte("\n\n"))
+}
+
+func (rw *ResponseRewriter) enableStreaming(reason string) error {
+ if rw.isStreaming {
+ return nil
+ }
+ rw.isStreaming = true
+
+ // Flush any previously buffered data to avoid reordering or data loss.
+ if rw.body != nil && rw.body.Len() > 0 {
+ buf := rw.body.Bytes()
+ // Copy before Reset() to keep bytes stable.
+ toFlush := make([]byte, len(buf))
+ copy(toFlush, buf)
+ rw.body.Reset()
+
+ if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil {
+ return err
+ }
+ if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
+ flusher.Flush()
+ }
+ }
+
+ log.Debugf("amp response rewriter: switched to streaming (%s)", reason)
+ return nil
+}
+
// Write intercepts response writes and buffers them for model name replacement
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
- // Detect streaming on first write
- if rw.body.Len() == 0 && !rw.isStreaming {
+ // Detect streaming on first write (header-based)
+ if !rw.isStreaming && rw.body.Len() == 0 {
contentType := rw.Header().Get("Content-Type")
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
strings.Contains(contentType, "stream")
}
+ if !rw.isStreaming {
+ // Content-based fallback: detect SSE-like chunks even if Content-Type is missing/wrong.
+ if looksLikeSSEChunk(data) {
+ if err := rw.enableStreaming("sse heuristic"); err != nil {
+ return 0, err
+ }
+ } else if rw.body.Len()+len(data) > maxBufferedResponseBytes {
+ // Safety cap: avoid unbounded buffering on large responses.
+ log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes)
+ if err := rw.enableStreaming("buffer limit"); err != nil {
+ return 0, err
+ }
+ }
+ }
+
if rw.isStreaming {
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
if err == nil {
diff --git a/internal/api/server.go b/internal/api/server.go
index 76e9a33a..98041b8b 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -24,6 +24,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
@@ -292,6 +293,11 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
s.registerManagementRoutes()
}
+ // === CLIProxyAPIPlus 扩展: 注册 Kiro OAuth Web 路由 ===
+ kiroOAuthHandler := kiro.NewOAuthWebHandler(cfg)
+ kiroOAuthHandler.RegisterRoutes(engine)
+ log.Info("Kiro OAuth Web routes registered at /v0/oauth/kiro/*")
+
if optionState.keepAliveEnabled {
s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
}
@@ -349,6 +355,12 @@ func (s *Server) setupRoutes() {
},
})
})
+
+ // Event logging endpoint - handles Claude Code telemetry requests
+ // Returns 200 OK to prevent 404 errors in logs
+ s.engine.POST("/api/event_logging/batch", func(c *gin.Context) {
+ c.JSON(http.StatusOK, gin.H{"status": "ok"})
+ })
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
// OAuth callback endpoints (reuse main server port)
@@ -424,6 +436,20 @@ func (s *Server) setupRoutes() {
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
+ s.engine.GET("/kiro/callback", func(c *gin.Context) {
+ code := c.Query("code")
+ state := c.Query("state")
+ errStr := c.Query("error")
+ if errStr == "" {
+ errStr = c.Query("error_description")
+ }
+ if state != "" {
+ _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "kiro", state, code, errStr)
+ }
+ c.Header("Content-Type", "text/html; charset=utf-8")
+ c.String(http.StatusOK, oauthCallbackSuccessHTML)
+ })
+
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
}
@@ -626,9 +652,12 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
+ mgmt.GET("/kilo-auth-url", s.mgmt.RequestKiloToken)
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
+ mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
+ mgmt.GET("/github-auth-url", s.mgmt.RequestGitHubToken)
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
}
diff --git a/internal/auth/claude/oauth_server.go b/internal/auth/claude/oauth_server.go
index a6ebe2f7..49b04794 100644
--- a/internal/auth/claude/oauth_server.go
+++ b/internal/auth/claude/oauth_server.go
@@ -242,6 +242,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
platformURL = "https://console.anthropic.com/"
}
+ // Validate platformURL to prevent XSS - only allow http/https URLs
+ if !isValidURL(platformURL) {
+ platformURL = "https://console.anthropic.com/"
+ }
+
// Generate success page HTML with dynamic content
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
@@ -251,6 +256,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
}
}
+// isValidURL checks if the URL is a valid http/https URL to prevent XSS
+func isValidURL(urlStr string) bool {
+ urlStr = strings.TrimSpace(urlStr)
+ return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
+}
+
// generateSuccessHTML creates the HTML content for the success page.
// It customizes the page based on whether additional setup is required
// and includes a link to the platform.
diff --git a/internal/auth/codex/oauth_server.go b/internal/auth/codex/oauth_server.go
index 9c6a6c5b..58b5394e 100644
--- a/internal/auth/codex/oauth_server.go
+++ b/internal/auth/codex/oauth_server.go
@@ -239,6 +239,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
platformURL = "https://platform.openai.com"
}
+ // Validate platformURL to prevent XSS - only allow http/https URLs
+ if !isValidURL(platformURL) {
+ platformURL = "https://platform.openai.com"
+ }
+
// Generate success page HTML with dynamic content
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
@@ -248,6 +253,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
}
}
+// isValidURL checks if the URL is a valid http/https URL to prevent XSS
+func isValidURL(urlStr string) bool {
+ urlStr = strings.TrimSpace(urlStr)
+ return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
+}
+
// generateSuccessHTML creates the HTML content for the success page.
// It customizes the page based on whether additional setup is required
// and includes a link to the platform.
diff --git a/internal/auth/copilot/copilot_auth.go b/internal/auth/copilot/copilot_auth.go
new file mode 100644
index 00000000..c40e7082
--- /dev/null
+++ b/internal/auth/copilot/copilot_auth.go
@@ -0,0 +1,225 @@
+// Package copilot provides authentication and token management for GitHub Copilot API.
+// It handles the OAuth2 device flow for secure authentication with the Copilot API.
+package copilot
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ // copilotAPITokenURL is the endpoint for getting Copilot API tokens from GitHub token.
+ copilotAPITokenURL = "https://api.github.com/copilot_internal/v2/token"
+ // copilotAPIEndpoint is the base URL for making API requests.
+ copilotAPIEndpoint = "https://api.githubcopilot.com"
+
+ // Common HTTP header values for Copilot API requests.
+ copilotUserAgent = "GithubCopilot/1.0"
+ copilotEditorVersion = "vscode/1.100.0"
+ copilotPluginVersion = "copilot/1.300.0"
+ copilotIntegrationID = "vscode-chat"
+ copilotOpenAIIntent = "conversation-panel"
+)
+
+// CopilotAPIToken represents the Copilot API token response.
+type CopilotAPIToken struct {
+ // Token is the JWT token for authenticating with the Copilot API.
+ Token string `json:"token"`
+ // ExpiresAt is the Unix timestamp when the token expires.
+ ExpiresAt int64 `json:"expires_at"`
+ // Endpoints contains the available API endpoints.
+ Endpoints struct {
+ API string `json:"api"`
+ Proxy string `json:"proxy"`
+ OriginTracker string `json:"origin-tracker"`
+ Telemetry string `json:"telemetry"`
+ } `json:"endpoints,omitempty"`
+ // ErrorDetails contains error information if the request failed.
+ ErrorDetails *struct {
+ URL string `json:"url"`
+ Message string `json:"message"`
+ DocumentationURL string `json:"documentation_url"`
+ } `json:"error_details,omitempty"`
+}
+
+// CopilotAuth handles GitHub Copilot authentication flow.
+// It provides methods for device flow authentication and token management.
+type CopilotAuth struct {
+ httpClient *http.Client
+ deviceClient *DeviceFlowClient
+ cfg *config.Config
+}
+
+// NewCopilotAuth creates a new CopilotAuth service instance.
+// It initializes an HTTP client with proxy settings from the provided configuration.
+func NewCopilotAuth(cfg *config.Config) *CopilotAuth {
+ return &CopilotAuth{
+ httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
+ deviceClient: NewDeviceFlowClient(cfg),
+ cfg: cfg,
+ }
+}
+
+// StartDeviceFlow initiates the device flow authentication.
+// Returns the device code response containing the user code and verification URI.
+func (c *CopilotAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) {
+ return c.deviceClient.RequestDeviceCode(ctx)
+}
+
+// WaitForAuthorization polls for user authorization and returns the auth bundle.
+func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotAuthBundle, error) {
+ tokenData, err := c.deviceClient.PollForToken(ctx, deviceCode)
+ if err != nil {
+ return nil, err
+ }
+
+ // Fetch the GitHub username
+ username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
+ if err != nil {
+ log.Warnf("copilot: failed to fetch user info: %v", err)
+ username = "unknown"
+ }
+
+ return &CopilotAuthBundle{
+ TokenData: tokenData,
+ Username: username,
+ }, nil
+}
+
+// GetCopilotAPIToken exchanges a GitHub access token for a Copilot API token.
+// This token is used to make authenticated requests to the Copilot API.
+func (c *CopilotAuth) GetCopilotAPIToken(ctx context.Context, githubAccessToken string) (*CopilotAPIToken, error) {
+ if githubAccessToken == "" {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("github access token is empty"))
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotAPITokenURL, nil)
+ if err != nil {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
+ }
+
+ req.Header.Set("Authorization", "token "+githubAccessToken)
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("User-Agent", copilotUserAgent)
+ req.Header.Set("Editor-Version", copilotEditorVersion)
+ req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
+ }
+ defer func() {
+ if errClose := resp.Body.Close(); errClose != nil {
+ log.Errorf("copilot api token: close body error: %v", errClose)
+ }
+ }()
+
+ bodyBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
+ }
+
+ if !isHTTPSuccess(resp.StatusCode) {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed,
+ fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
+ }
+
+ var apiToken CopilotAPIToken
+ if err = json.Unmarshal(bodyBytes, &apiToken); err != nil {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
+ }
+
+ if apiToken.Token == "" {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty copilot api token"))
+ }
+
+ return &apiToken, nil
+}
+
+// ValidateToken checks if a GitHub access token is valid by attempting to fetch user info.
+func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bool, string, error) {
+ if accessToken == "" {
+ return false, "", nil
+ }
+
+ username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
+ if err != nil {
+ return false, "", err
+ }
+
+ return true, username, nil
+}
+
+// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
+func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotTokenStorage {
+ return &CopilotTokenStorage{
+ AccessToken: bundle.TokenData.AccessToken,
+ TokenType: bundle.TokenData.TokenType,
+ Scope: bundle.TokenData.Scope,
+ Username: bundle.Username,
+ Type: "github-copilot",
+ }
+}
+
+// LoadAndValidateToken loads a token from storage and validates it.
+// Returns the storage if valid, or an error if the token is invalid or expired.
+func (c *CopilotAuth) LoadAndValidateToken(ctx context.Context, storage *CopilotTokenStorage) (bool, error) {
+ if storage == nil || storage.AccessToken == "" {
+ return false, fmt.Errorf("no token available")
+ }
+
+ // Check if we can still use the GitHub token to get a Copilot API token
+ apiToken, err := c.GetCopilotAPIToken(ctx, storage.AccessToken)
+ if err != nil {
+ return false, err
+ }
+
+ // Check if the API token is expired
+ if apiToken.ExpiresAt > 0 && time.Now().Unix() >= apiToken.ExpiresAt {
+ return false, fmt.Errorf("copilot api token expired")
+ }
+
+ return true, nil
+}
+
+// GetAPIEndpoint returns the Copilot API endpoint URL.
+func (c *CopilotAuth) GetAPIEndpoint() string {
+ return copilotAPIEndpoint
+}
+
+// MakeAuthenticatedRequest creates an authenticated HTTP request to the Copilot API.
+func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url string, body io.Reader, apiToken *CopilotAPIToken) (*http.Request, error) {
+ req, err := http.NewRequestWithContext(ctx, method, url, body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Authorization", "Bearer "+apiToken.Token)
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("User-Agent", copilotUserAgent)
+ req.Header.Set("Editor-Version", copilotEditorVersion)
+ req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
+ req.Header.Set("Openai-Intent", copilotOpenAIIntent)
+ req.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
+
+ return req, nil
+}
+
+// buildChatCompletionURL builds the URL for chat completions API.
+func buildChatCompletionURL() string {
+ return copilotAPIEndpoint + "/chat/completions"
+}
+
+// isHTTPSuccess checks if the status code indicates success (2xx).
+func isHTTPSuccess(statusCode int) bool {
+ return statusCode >= 200 && statusCode < 300
+}
diff --git a/internal/auth/copilot/errors.go b/internal/auth/copilot/errors.go
new file mode 100644
index 00000000..a82dd8ec
--- /dev/null
+++ b/internal/auth/copilot/errors.go
@@ -0,0 +1,187 @@
+package copilot
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+)
+
+// OAuthError represents an OAuth-specific error.
+type OAuthError struct {
+ // Code is the OAuth error code.
+ Code string `json:"error"`
+ // Description is a human-readable description of the error.
+ Description string `json:"error_description,omitempty"`
+ // URI is a URI identifying a human-readable web page with information about the error.
+ URI string `json:"error_uri,omitempty"`
+ // StatusCode is the HTTP status code associated with the error.
+ StatusCode int `json:"-"`
+}
+
+// Error returns a string representation of the OAuth error.
+func (e *OAuthError) Error() string {
+ if e.Description != "" {
+ return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
+ }
+ return fmt.Sprintf("OAuth error: %s", e.Code)
+}
+
+// NewOAuthError creates a new OAuth error with the specified code, description, and status code.
+func NewOAuthError(code, description string, statusCode int) *OAuthError {
+ return &OAuthError{
+ Code: code,
+ Description: description,
+ StatusCode: statusCode,
+ }
+}
+
+// AuthenticationError represents authentication-related errors.
+type AuthenticationError struct {
+ // Type is the type of authentication error.
+ Type string `json:"type"`
+ // Message is a human-readable message describing the error.
+ Message string `json:"message"`
+ // Code is the HTTP status code associated with the error.
+ Code int `json:"code"`
+ // Cause is the underlying error that caused this authentication error.
+ Cause error `json:"-"`
+}
+
+// Error returns a string representation of the authentication error.
+func (e *AuthenticationError) Error() string {
+ if e.Cause != nil {
+ return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
+ }
+ return fmt.Sprintf("%s: %s", e.Type, e.Message)
+}
+
+// Unwrap returns the underlying cause of the error.
+func (e *AuthenticationError) Unwrap() error {
+ return e.Cause
+}
+
+// Common authentication error types for GitHub Copilot device flow.
+var (
+ // ErrDeviceCodeFailed represents an error when requesting the device code fails.
+ ErrDeviceCodeFailed = &AuthenticationError{
+ Type: "device_code_failed",
+ Message: "Failed to request device code from GitHub",
+ Code: http.StatusBadRequest,
+ }
+
+ // ErrDeviceCodeExpired represents an error when the device code has expired.
+ ErrDeviceCodeExpired = &AuthenticationError{
+ Type: "device_code_expired",
+ Message: "Device code has expired. Please try again.",
+ Code: http.StatusGone,
+ }
+
+ // ErrAuthorizationPending represents a pending authorization state (not an error, used for polling).
+ ErrAuthorizationPending = &AuthenticationError{
+ Type: "authorization_pending",
+ Message: "Authorization is pending. Waiting for user to authorize.",
+ Code: http.StatusAccepted,
+ }
+
+ // ErrSlowDown represents a request to slow down polling.
+ ErrSlowDown = &AuthenticationError{
+ Type: "slow_down",
+ Message: "Polling too frequently. Slowing down.",
+ Code: http.StatusTooManyRequests,
+ }
+
+ // ErrAccessDenied represents an error when the user denies authorization.
+ ErrAccessDenied = &AuthenticationError{
+ Type: "access_denied",
+ Message: "User denied authorization",
+ Code: http.StatusForbidden,
+ }
+
+ // ErrTokenExchangeFailed represents an error when token exchange fails.
+ ErrTokenExchangeFailed = &AuthenticationError{
+ Type: "token_exchange_failed",
+ Message: "Failed to exchange device code for access token",
+ Code: http.StatusBadRequest,
+ }
+
+ // ErrPollingTimeout represents an error when polling times out.
+ ErrPollingTimeout = &AuthenticationError{
+ Type: "polling_timeout",
+ Message: "Timeout waiting for user authorization",
+ Code: http.StatusRequestTimeout,
+ }
+
+ // ErrUserInfoFailed represents an error when fetching user info fails.
+ ErrUserInfoFailed = &AuthenticationError{
+ Type: "user_info_failed",
+ Message: "Failed to fetch GitHub user information",
+ Code: http.StatusBadRequest,
+ }
+)
+
+// NewAuthenticationError creates a new authentication error with a cause based on a base error.
+func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
+ return &AuthenticationError{
+ Type: baseErr.Type,
+ Message: baseErr.Message,
+ Code: baseErr.Code,
+ Cause: cause,
+ }
+}
+
+// IsAuthenticationError checks if an error is an authentication error.
+func IsAuthenticationError(err error) bool {
+ var authenticationError *AuthenticationError
+ ok := errors.As(err, &authenticationError)
+ return ok
+}
+
+// IsOAuthError checks if an error is an OAuth error.
+func IsOAuthError(err error) bool {
+ var oAuthError *OAuthError
+ ok := errors.As(err, &oAuthError)
+ return ok
+}
+
+// GetUserFriendlyMessage returns a user-friendly error message based on the error type.
+func GetUserFriendlyMessage(err error) string {
+ var authErr *AuthenticationError
+ if errors.As(err, &authErr) {
+ switch authErr.Type {
+ case "device_code_failed":
+ return "Failed to start GitHub authentication. Please check your network connection and try again."
+ case "device_code_expired":
+ return "The authentication code has expired. Please try again."
+ case "authorization_pending":
+ return "Waiting for you to authorize the application on GitHub."
+ case "slow_down":
+ return "Please wait a moment before trying again."
+ case "access_denied":
+ return "Authentication was cancelled or denied."
+ case "token_exchange_failed":
+ return "Failed to complete authentication. Please try again."
+ case "polling_timeout":
+ return "Authentication timed out. Please try again."
+ case "user_info_failed":
+ return "Failed to get your GitHub account information. Please try again."
+ default:
+ return "Authentication failed. Please try again."
+ }
+ }
+
+ var oauthErr *OAuthError
+ if errors.As(err, &oauthErr) {
+ switch oauthErr.Code {
+ case "access_denied":
+ return "Authentication was cancelled or denied."
+ case "invalid_request":
+ return "Invalid authentication request. Please try again."
+ case "server_error":
+ return "GitHub server error. Please try again later."
+ default:
+ return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
+ }
+ }
+
+ return "An unexpected error occurred. Please try again."
+}
diff --git a/internal/auth/copilot/oauth.go b/internal/auth/copilot/oauth.go
new file mode 100644
index 00000000..d3f46aaa
--- /dev/null
+++ b/internal/auth/copilot/oauth.go
@@ -0,0 +1,255 @@
+package copilot
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ // copilotClientID is GitHub's Copilot CLI OAuth client ID.
+ copilotClientID = "Iv1.b507a08c87ecfe98"
+ // copilotDeviceCodeURL is the endpoint for requesting device codes.
+ copilotDeviceCodeURL = "https://github.com/login/device/code"
+ // copilotTokenURL is the endpoint for exchanging device codes for tokens.
+ copilotTokenURL = "https://github.com/login/oauth/access_token"
+ // copilotUserInfoURL is the endpoint for fetching GitHub user information.
+ copilotUserInfoURL = "https://api.github.com/user"
+ // defaultPollInterval is the default interval for polling token endpoint.
+ defaultPollInterval = 5 * time.Second
+ // maxPollDuration is the maximum time to wait for user authorization.
+ maxPollDuration = 15 * time.Minute
+)
+
+// DeviceFlowClient handles the OAuth2 device flow for GitHub Copilot.
+type DeviceFlowClient struct {
+ httpClient *http.Client
+ cfg *config.Config
+}
+
+// NewDeviceFlowClient creates a new device flow client.
+func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
+ client := &http.Client{Timeout: 30 * time.Second}
+ if cfg != nil {
+ client = util.SetProxy(&cfg.SDKConfig, client)
+ }
+ return &DeviceFlowClient{
+ httpClient: client,
+ cfg: cfg,
+ }
+}
+
+// RequestDeviceCode initiates the device flow by requesting a device code from GitHub.
+func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
+ data := url.Values{}
+ data.Set("client_id", copilotClientID)
+ data.Set("scope", "user:email")
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
+ if err != nil {
+ return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
+ }
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
+ }
+ defer func() {
+ if errClose := resp.Body.Close(); errClose != nil {
+ log.Errorf("copilot device code: close body error: %v", errClose)
+ }
+ }()
+
+ if !isHTTPSuccess(resp.StatusCode) {
+ bodyBytes, _ := io.ReadAll(resp.Body)
+ return nil, NewAuthenticationError(ErrDeviceCodeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
+ }
+
+ var deviceCode DeviceCodeResponse
+ if err = json.NewDecoder(resp.Body).Decode(&deviceCode); err != nil {
+ return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
+ }
+
+ return &deviceCode, nil
+}
+
+// PollForToken polls the token endpoint until the user authorizes or the device code expires.
+func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotTokenData, error) {
+ if deviceCode == nil {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("device code is nil"))
+ }
+
+ interval := time.Duration(deviceCode.Interval) * time.Second
+ if interval < defaultPollInterval {
+ interval = defaultPollInterval
+ }
+
+ deadline := time.Now().Add(maxPollDuration)
+ if deviceCode.ExpiresIn > 0 {
+ codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
+ if codeDeadline.Before(deadline) {
+ deadline = codeDeadline
+ }
+ }
+
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return nil, NewAuthenticationError(ErrPollingTimeout, ctx.Err())
+ case <-ticker.C:
+ if time.Now().After(deadline) {
+ return nil, ErrPollingTimeout
+ }
+
+ token, err := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode)
+ if err != nil {
+ var authErr *AuthenticationError
+ if errors.As(err, &authErr) {
+ switch authErr.Type {
+ case ErrAuthorizationPending.Type:
+ // Continue polling
+ continue
+ case ErrSlowDown.Type:
+ // Increase interval and continue
+ interval += 5 * time.Second
+ ticker.Reset(interval)
+ continue
+ case ErrDeviceCodeExpired.Type:
+ return nil, err
+ case ErrAccessDenied.Type:
+ return nil, err
+ }
+ }
+ return nil, err
+ }
+ return token, nil
+ }
+ }
+}
+
+// exchangeDeviceCode attempts to exchange the device code for an access token.
+func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*CopilotTokenData, error) {
+ data := url.Values{}
+ data.Set("client_id", copilotClientID)
+ data.Set("device_code", deviceCode)
+ data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotTokenURL, strings.NewReader(data.Encode()))
+ if err != nil {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
+ }
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
+ }
+ defer func() {
+ if errClose := resp.Body.Close(); errClose != nil {
+ log.Errorf("copilot token exchange: close body error: %v", errClose)
+ }
+ }()
+
+ bodyBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
+ }
+
+ // GitHub returns 200 for both success and error cases in device flow
+ // Check for OAuth error response first
+ var oauthResp struct {
+ Error string `json:"error"`
+ ErrorDescription string `json:"error_description"`
+ AccessToken string `json:"access_token"`
+ TokenType string `json:"token_type"`
+ Scope string `json:"scope"`
+ }
+
+ if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
+ }
+
+ if oauthResp.Error != "" {
+ switch oauthResp.Error {
+ case "authorization_pending":
+ return nil, ErrAuthorizationPending
+ case "slow_down":
+ return nil, ErrSlowDown
+ case "expired_token":
+ return nil, ErrDeviceCodeExpired
+ case "access_denied":
+ return nil, ErrAccessDenied
+ default:
+ return nil, NewOAuthError(oauthResp.Error, oauthResp.ErrorDescription, resp.StatusCode)
+ }
+ }
+
+ if oauthResp.AccessToken == "" {
+ return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty access token"))
+ }
+
+ return &CopilotTokenData{
+ AccessToken: oauthResp.AccessToken,
+ TokenType: oauthResp.TokenType,
+ Scope: oauthResp.Scope,
+ }, nil
+}
+
+// FetchUserInfo retrieves the GitHub username for the authenticated user.
+func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
+ if accessToken == "" {
+ return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
+ if err != nil {
+ return "", NewAuthenticationError(ErrUserInfoFailed, err)
+ }
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("User-Agent", "CLIProxyAPI")
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return "", NewAuthenticationError(ErrUserInfoFailed, err)
+ }
+ defer func() {
+ if errClose := resp.Body.Close(); errClose != nil {
+ log.Errorf("copilot user info: close body error: %v", errClose)
+ }
+ }()
+
+ if !isHTTPSuccess(resp.StatusCode) {
+ bodyBytes, _ := io.ReadAll(resp.Body)
+ return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
+ }
+
+ var userInfo struct {
+ Login string `json:"login"`
+ }
+ if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
+ return "", NewAuthenticationError(ErrUserInfoFailed, err)
+ }
+
+ if userInfo.Login == "" {
+ return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
+ }
+
+ return userInfo.Login, nil
+}
diff --git a/internal/auth/copilot/token.go b/internal/auth/copilot/token.go
new file mode 100644
index 00000000..4e5eed6c
--- /dev/null
+++ b/internal/auth/copilot/token.go
@@ -0,0 +1,93 @@
+// Package copilot provides authentication and token management functionality
+// for GitHub Copilot AI services. It handles OAuth2 device flow token storage,
+// serialization, and retrieval for maintaining authenticated sessions with the Copilot API.
+package copilot
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
+)
+
+// CopilotTokenStorage stores OAuth2 token information for GitHub Copilot API authentication.
+// It maintains compatibility with the existing auth system while adding Copilot-specific fields
+// for managing access tokens and user account information.
+type CopilotTokenStorage struct {
+ // AccessToken is the OAuth2 access token used for authenticating API requests.
+ AccessToken string `json:"access_token"`
+ // TokenType is the type of token, typically "bearer".
+ TokenType string `json:"token_type"`
+ // Scope is the OAuth2 scope granted to the token.
+ Scope string `json:"scope"`
+ // ExpiresAt is the timestamp when the access token expires (if provided).
+ ExpiresAt string `json:"expires_at,omitempty"`
+ // Username is the GitHub username associated with this token.
+ Username string `json:"username"`
+ // Type indicates the authentication provider type, always "github-copilot" for this storage.
+ Type string `json:"type"`
+}
+
+// CopilotTokenData holds the raw OAuth token response from GitHub.
+type CopilotTokenData struct {
+ // AccessToken is the OAuth2 access token.
+ AccessToken string `json:"access_token"`
+ // TokenType is the type of token, typically "bearer".
+ TokenType string `json:"token_type"`
+ // Scope is the OAuth2 scope granted to the token.
+ Scope string `json:"scope"`
+}
+
+// CopilotAuthBundle bundles authentication data for storage.
+type CopilotAuthBundle struct {
+ // TokenData contains the OAuth token information.
+ TokenData *CopilotTokenData
+ // Username is the GitHub username.
+ Username string
+}
+
+// DeviceCodeResponse represents GitHub's device code response.
+type DeviceCodeResponse struct {
+ // DeviceCode is the device verification code.
+ DeviceCode string `json:"device_code"`
+ // UserCode is the code the user must enter at the verification URI.
+ UserCode string `json:"user_code"`
+ // VerificationURI is the URL where the user should enter the code.
+ VerificationURI string `json:"verification_uri"`
+ // ExpiresIn is the number of seconds until the device code expires.
+ ExpiresIn int `json:"expires_in"`
+ // Interval is the minimum number of seconds to wait between polling requests.
+ Interval int `json:"interval"`
+}
+
+// SaveTokenToFile serializes the Copilot token storage to a JSON file.
+// This method creates the necessary directory structure and writes the token
+// data in JSON format to the specified file path for persistent storage.
+//
+// Parameters:
+// - authFilePath: The full path where the token file should be saved
+//
+// Returns:
+// - error: An error if the operation fails, nil otherwise
+func (ts *CopilotTokenStorage) SaveTokenToFile(authFilePath string) error {
+ misc.LogSavingCredentials(authFilePath)
+ ts.Type = "github-copilot"
+ if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
+ return fmt.Errorf("failed to create directory: %v", err)
+ }
+
+ f, err := os.Create(authFilePath)
+ if err != nil {
+ return fmt.Errorf("failed to create token file: %w", err)
+ }
+ defer func() {
+ _ = f.Close()
+ }()
+
+ if err = json.NewEncoder(f).Encode(ts); err != nil {
+ return fmt.Errorf("failed to write token to file: %w", err)
+ }
+ return nil
+}
diff --git a/internal/auth/iflow/iflow_auth.go b/internal/auth/iflow/iflow_auth.go
index fa9f38c3..279d7339 100644
--- a/internal/auth/iflow/iflow_auth.go
+++ b/internal/auth/iflow/iflow_auth.go
@@ -9,6 +9,7 @@ import (
"io"
"net/http"
"net/url"
+ "os"
"strings"
"time"
@@ -28,10 +29,21 @@ const (
iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey"
// Client credentials provided by iFlow for the Code Assist integration.
- iFlowOAuthClientID = "10009311001"
- iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
+ iFlowOAuthClientID = "10009311001"
+ // Default client secret (can be overridden via IFLOW_CLIENT_SECRET env var)
+ defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
)
+// getIFlowClientSecret returns the iFlow OAuth client secret.
+// It first checks the IFLOW_CLIENT_SECRET environment variable,
+// falling back to the default value if not set.
+func getIFlowClientSecret() string {
+ if secret := os.Getenv("IFLOW_CLIENT_SECRET"); secret != "" {
+ return secret
+ }
+ return defaultIFlowClientSecret
+}
+
// DefaultAPIBaseURL is the canonical chat completions endpoint.
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1"
@@ -72,7 +84,7 @@ func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectUR
form.Set("code", code)
form.Set("redirect_uri", redirectURI)
form.Set("client_id", iFlowOAuthClientID)
- form.Set("client_secret", iFlowOAuthClientSecret)
+ form.Set("client_secret", getIFlowClientSecret())
req, err := ia.newTokenRequest(ctx, form)
if err != nil {
@@ -88,7 +100,7 @@ func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*I
form.Set("grant_type", "refresh_token")
form.Set("refresh_token", refreshToken)
form.Set("client_id", iFlowOAuthClientID)
- form.Set("client_secret", iFlowOAuthClientSecret)
+ form.Set("client_secret", getIFlowClientSecret())
req, err := ia.newTokenRequest(ctx, form)
if err != nil {
@@ -104,7 +116,7 @@ func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*htt
return nil, fmt.Errorf("iflow token: create request failed: %w", err)
}
- basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret))
+ basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + getIFlowClientSecret()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Basic "+basic)
diff --git a/internal/auth/kilo/kilo_auth.go b/internal/auth/kilo/kilo_auth.go
new file mode 100644
index 00000000..dc128bf2
--- /dev/null
+++ b/internal/auth/kilo/kilo_auth.go
@@ -0,0 +1,168 @@
+// Package kilo provides authentication and token management functionality
+// for Kilo AI services.
+package kilo
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "time"
+)
+
+const (
+ // BaseURL is the base URL for the Kilo AI API.
+ BaseURL = "https://api.kilo.ai/api"
+)
+
+// DeviceAuthResponse represents the response from initiating device flow.
+type DeviceAuthResponse struct {
+ Code string `json:"code"`
+ VerificationURL string `json:"verificationUrl"`
+ ExpiresIn int `json:"expiresIn"`
+}
+
+// DeviceStatusResponse represents the response when polling for device flow status.
+type DeviceStatusResponse struct {
+ Status string `json:"status"`
+ Token string `json:"token"`
+ UserEmail string `json:"userEmail"`
+}
+
+// Profile represents the user profile from Kilo AI.
+type Profile struct {
+ Email string `json:"email"`
+ Orgs []Organization `json:"organizations"`
+}
+
+// Organization represents a Kilo AI organization.
+type Organization struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+}
+
+// Defaults represents default settings for an organization or user.
+type Defaults struct {
+ Model string `json:"model"`
+}
+
+// KiloAuth provides methods for handling the Kilo AI authentication flow.
+type KiloAuth struct {
+ client *http.Client
+}
+
+// NewKiloAuth creates a new instance of KiloAuth.
+func NewKiloAuth() *KiloAuth {
+ return &KiloAuth{
+ client: &http.Client{Timeout: 30 * time.Second},
+ }
+}
+
+// InitiateDeviceFlow starts the device authentication flow.
+func (k *KiloAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceAuthResponse, error) {
+ resp, err := k.client.Post(BaseURL+"/device-auth/codes", "application/json", nil)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to initiate device flow: status %d", resp.StatusCode)
+ }
+
+ var data DeviceAuthResponse
+ if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
+ return nil, err
+ }
+ return &data, nil
+}
+
+// PollForToken polls for the device flow completion.
+func (k *KiloAuth) PollForToken(ctx context.Context, code string) (*DeviceStatusResponse, error) {
+ ticker := time.NewTicker(5 * time.Second)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-ticker.C:
+ resp, err := k.client.Get(BaseURL + "/device-auth/codes/" + code)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ var data DeviceStatusResponse
+ if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
+ return nil, err
+ }
+
+ switch data.Status {
+ case "approved":
+ return &data, nil
+ case "denied", "expired":
+ return nil, fmt.Errorf("device flow %s", data.Status)
+ case "pending":
+ continue
+ default:
+ return nil, fmt.Errorf("unknown status: %s", data.Status)
+ }
+ }
+ }
+}
+
+// GetProfile fetches the user's profile.
+func (k *KiloAuth) GetProfile(ctx context.Context, token string) (*Profile, error) {
+ req, err := http.NewRequestWithContext(ctx, "GET", BaseURL+"/profile", nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create get profile request: %w", err)
+ }
+ req.Header.Set("Authorization", "Bearer "+token)
+
+ resp, err := k.client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get profile: status %d", resp.StatusCode)
+ }
+
+ var profile Profile
+ if err := json.NewDecoder(resp.Body).Decode(&profile); err != nil {
+ return nil, err
+ }
+ return &profile, nil
+}
+
+// GetDefaults fetches default settings for an organization.
+func (k *KiloAuth) GetDefaults(ctx context.Context, token, orgID string) (*Defaults, error) {
+ url := BaseURL + "/defaults"
+ if orgID != "" {
+ url = BaseURL + "/organizations/" + orgID + "/defaults"
+ }
+
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create get defaults request: %w", err)
+ }
+ req.Header.Set("Authorization", "Bearer "+token)
+
+ resp, err := k.client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("failed to get defaults: status %d", resp.StatusCode)
+ }
+
+ var defaults Defaults
+ if err := json.NewDecoder(resp.Body).Decode(&defaults); err != nil {
+ return nil, err
+ }
+ return &defaults, nil
+}
diff --git a/internal/auth/kilo/kilo_token.go b/internal/auth/kilo/kilo_token.go
new file mode 100644
index 00000000..5d1646e7
--- /dev/null
+++ b/internal/auth/kilo/kilo_token.go
@@ -0,0 +1,60 @@
+// Package kilo provides authentication and token management functionality
+// for Kilo AI services.
+package kilo
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
+ log "github.com/sirupsen/logrus"
+)
+
+// KiloTokenStorage stores token information for Kilo AI authentication.
+type KiloTokenStorage struct {
+ // Token is the Kilo access token.
+ Token string `json:"kilocodeToken"`
+
+ // OrganizationID is the Kilo organization ID.
+ OrganizationID string `json:"kilocodeOrganizationId"`
+
+ // Model is the default model to use.
+ Model string `json:"kilocodeModel"`
+
+ // Email is the email address of the authenticated user.
+ Email string `json:"email"`
+
+ // Type indicates the authentication provider type, always "kilo" for this storage.
+ Type string `json:"type"`
+}
+
+// SaveTokenToFile serializes the Kilo token storage to a JSON file.
+func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error {
+ misc.LogSavingCredentials(authFilePath)
+ ts.Type = "kilo"
+ if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
+ return fmt.Errorf("failed to create directory: %v", err)
+ }
+
+ f, err := os.Create(authFilePath)
+ if err != nil {
+ return fmt.Errorf("failed to create token file: %w", err)
+ }
+ defer func() {
+ if errClose := f.Close(); errClose != nil {
+ log.Errorf("failed to close file: %v", errClose)
+ }
+ }()
+
+ if err = json.NewEncoder(f).Encode(ts); err != nil {
+ return fmt.Errorf("failed to write token to file: %w", err)
+ }
+ return nil
+}
+
+// CredentialFileName returns the filename used to persist Kilo credentials.
+func CredentialFileName(email string) string {
+ return fmt.Sprintf("kilo-%s.json", email)
+}
diff --git a/internal/auth/kiro/aws.go b/internal/auth/kiro/aws.go
new file mode 100644
index 00000000..6ec67c49
--- /dev/null
+++ b/internal/auth/kiro/aws.go
@@ -0,0 +1,522 @@
+// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
+// It includes interfaces and implementations for token storage and authentication methods.
+package kiro
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+)
+
+// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
+type PKCECodes struct {
+ // CodeVerifier is the cryptographically random string used to correlate
+ // the authorization request to the token request
+ CodeVerifier string `json:"code_verifier"`
+ // CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
+ CodeChallenge string `json:"code_challenge"`
+}
+
+// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro)
+type KiroTokenData struct {
+ // AccessToken is the OAuth2 access token for API access
+ AccessToken string `json:"accessToken"`
+ // RefreshToken is used to obtain new access tokens
+ RefreshToken string `json:"refreshToken"`
+ // ProfileArn is the AWS CodeWhisperer profile ARN
+ ProfileArn string `json:"profileArn"`
+ // ExpiresAt is the timestamp when the token expires
+ ExpiresAt string `json:"expiresAt"`
+ // AuthMethod indicates the authentication method used (e.g., "builder-id", "social", "idc")
+ AuthMethod string `json:"authMethod"`
+ // Provider indicates the OAuth provider (e.g., "AWS", "Google", "Enterprise")
+ Provider string `json:"provider"`
+ // ClientID is the OIDC client ID (needed for token refresh)
+ ClientID string `json:"clientId,omitempty"`
+ // ClientSecret is the OIDC client secret (needed for token refresh)
+ ClientSecret string `json:"clientSecret,omitempty"`
+ // ClientIDHash is the hash of client ID used to locate device registration file
+ // (Enterprise Kiro IDE stores clientId/clientSecret in ~/.aws/sso/cache/{clientIdHash}.json)
+ ClientIDHash string `json:"clientIdHash,omitempty"`
+ // Email is the user's email address (used for file naming)
+ Email string `json:"email,omitempty"`
+ // StartURL is the IDC/Identity Center start URL (only for IDC auth method)
+ StartURL string `json:"startUrl,omitempty"`
+ // Region is the AWS region for IDC authentication (only for IDC auth method)
+ Region string `json:"region,omitempty"`
+}
+
+// KiroAuthBundle aggregates authentication data after OAuth flow completion
+type KiroAuthBundle struct {
+ // TokenData contains the OAuth tokens from the authentication flow
+ TokenData KiroTokenData `json:"token_data"`
+ // LastRefresh is the timestamp of the last token refresh
+ LastRefresh string `json:"last_refresh"`
+}
+
+// KiroUsageInfo represents usage information from CodeWhisperer API
+type KiroUsageInfo struct {
+ // SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE")
+ SubscriptionTitle string `json:"subscription_title"`
+ // CurrentUsage is the current credit usage
+ CurrentUsage float64 `json:"current_usage"`
+ // UsageLimit is the maximum credit limit
+ UsageLimit float64 `json:"usage_limit"`
+ // NextReset is the timestamp of the next usage reset
+ NextReset string `json:"next_reset"`
+}
+
+// KiroModel represents a model available through the CodeWhisperer API
+type KiroModel struct {
+ // ModelID is the unique identifier for the model
+ ModelID string `json:"modelId"`
+ // ModelName is the human-readable name
+ ModelName string `json:"modelName"`
+ // Description is the model description
+ Description string `json:"description"`
+ // RateMultiplier is the credit multiplier for this model
+ RateMultiplier float64 `json:"rateMultiplier"`
+ // RateUnit is the unit for rate calculation (e.g., "credit")
+ RateUnit string `json:"rateUnit"`
+ // MaxInputTokens is the maximum input token limit
+ MaxInputTokens int `json:"maxInputTokens,omitempty"`
+}
+
+// KiroIDETokenFile is the default path to Kiro IDE's token file
+const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
+
+// Default retry configuration for file reading
+const (
+ defaultTokenReadMaxAttempts = 10 // Maximum retry attempts
+ defaultTokenReadBaseDelay = 50 * time.Millisecond // Base delay between retries
+)
+
+// isTransientFileError checks if the error is a transient file access error
+// that may be resolved by retrying (e.g., file locked by another process on Windows).
+func isTransientFileError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ // Check for OS-level file access errors (Windows sharing violation, etc.)
+ var pathErr *os.PathError
+ if errors.As(err, &pathErr) {
+ // Windows sharing violation (ERROR_SHARING_VIOLATION = 32)
+ // Windows lock violation (ERROR_LOCK_VIOLATION = 33)
+ errStr := pathErr.Err.Error()
+ if strings.Contains(errStr, "being used by another process") ||
+ strings.Contains(errStr, "sharing violation") ||
+ strings.Contains(errStr, "lock violation") {
+ return true
+ }
+ }
+
+ // Check error message for common transient patterns
+ errMsg := strings.ToLower(err.Error())
+ transientPatterns := []string{
+ "being used by another process",
+ "sharing violation",
+ "lock violation",
+ "access is denied",
+ "unexpected end of json",
+ "unexpected eof",
+ }
+ for _, pattern := range transientPatterns {
+ if strings.Contains(errMsg, pattern) {
+ return true
+ }
+ }
+
+ return false
+}
+
+// LoadKiroIDETokenWithRetry loads token data from Kiro IDE's token file with retry logic.
+// This handles transient file access errors (e.g., file locked by Kiro IDE during write).
+// maxAttempts: maximum number of retry attempts (default 10 if <= 0)
+// baseDelay: base delay between retries with exponential backoff (default 50ms if <= 0)
+func LoadKiroIDETokenWithRetry(maxAttempts int, baseDelay time.Duration) (*KiroTokenData, error) {
+ if maxAttempts <= 0 {
+ maxAttempts = defaultTokenReadMaxAttempts
+ }
+ if baseDelay <= 0 {
+ baseDelay = defaultTokenReadBaseDelay
+ }
+
+ var lastErr error
+ for attempt := 0; attempt < maxAttempts; attempt++ {
+ token, err := LoadKiroIDEToken()
+ if err == nil {
+ return token, nil
+ }
+ lastErr = err
+
+ // Only retry for transient errors
+ if !isTransientFileError(err) {
+ return nil, err
+ }
+
+ // Exponential backoff: delay * 2^attempt, capped at 500ms
+ delay := baseDelay * time.Duration(1< 500*time.Millisecond {
+ delay = 500 * time.Millisecond
+ }
+ time.Sleep(delay)
+ }
+
+ return nil, fmt.Errorf("failed to read token file after %d attempts: %w", maxAttempts, lastErr)
+}
+
+// LoadKiroIDEToken loads token data from Kiro IDE's token file.
+// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
+// from the device registration file referenced by clientIdHash.
+func LoadKiroIDEToken() (*KiroTokenData, error) {
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ return nil, fmt.Errorf("failed to get home directory: %w", err)
+ }
+
+ tokenPath := filepath.Join(homeDir, KiroIDETokenFile)
+ data, err := os.ReadFile(tokenPath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err)
+ }
+
+ var token KiroTokenData
+ if err := json.Unmarshal(data, &token); err != nil {
+ return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err)
+ }
+
+ if token.AccessToken == "" {
+ return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
+ }
+
+ // Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
+ token.AuthMethod = strings.ToLower(token.AuthMethod)
+
+ // For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
+ // The device registration file is located at ~/.aws/sso/cache/{clientIdHash}.json
+ if token.ClientIDHash != "" && token.ClientID == "" {
+ if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
+ // Log warning but don't fail - token might still work for some operations
+ fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
+ }
+ }
+
+ return &token, nil
+}
+
+// loadDeviceRegistration loads clientId and clientSecret from the device registration file.
+// Enterprise Kiro IDE stores these in ~/.aws/sso/cache/{clientIdHash}.json
+func loadDeviceRegistration(homeDir, clientIDHash string, token *KiroTokenData) error {
+ if clientIDHash == "" {
+ return fmt.Errorf("clientIdHash is empty")
+ }
+
+ // Sanitize clientIdHash to prevent path traversal
+ if strings.Contains(clientIDHash, "/") || strings.Contains(clientIDHash, "\\") || strings.Contains(clientIDHash, "..") {
+ return fmt.Errorf("invalid clientIdHash: contains path separator")
+ }
+
+ deviceRegPath := filepath.Join(homeDir, ".aws", "sso", "cache", clientIDHash+".json")
+ data, err := os.ReadFile(deviceRegPath)
+ if err != nil {
+ return fmt.Errorf("failed to read device registration file (%s): %w", deviceRegPath, err)
+ }
+
+ // Device registration file structure
+ var deviceReg struct {
+ ClientID string `json:"clientId"`
+ ClientSecret string `json:"clientSecret"`
+ ExpiresAt string `json:"expiresAt"`
+ }
+
+ if err := json.Unmarshal(data, &deviceReg); err != nil {
+ return fmt.Errorf("failed to parse device registration: %w", err)
+ }
+
+ if deviceReg.ClientID == "" || deviceReg.ClientSecret == "" {
+ return fmt.Errorf("device registration missing clientId or clientSecret")
+ }
+
+ token.ClientID = deviceReg.ClientID
+ token.ClientSecret = deviceReg.ClientSecret
+
+ return nil
+}
+
+// LoadKiroTokenFromPath loads token data from a custom path.
+// This supports multiple accounts by allowing different token files.
+// For Enterprise Kiro IDE (IDC auth), it also loads clientId and clientSecret
+// from the device registration file referenced by clientIdHash.
+func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ return nil, fmt.Errorf("failed to get home directory: %w", err)
+ }
+
+ // Expand ~ to home directory
+ if len(tokenPath) > 0 && tokenPath[0] == '~' {
+ tokenPath = filepath.Join(homeDir, tokenPath[1:])
+ }
+
+ data, err := os.ReadFile(tokenPath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err)
+ }
+
+ var token KiroTokenData
+ if err := json.Unmarshal(data, &token); err != nil {
+ return nil, fmt.Errorf("failed to parse token file: %w", err)
+ }
+
+ if token.AccessToken == "" {
+ return nil, fmt.Errorf("access token is empty in token file")
+ }
+
+ // Normalize AuthMethod to lowercase (Kiro IDE uses "IdC" but we expect "idc")
+ token.AuthMethod = strings.ToLower(token.AuthMethod)
+
+ // For Enterprise Kiro IDE (IDC auth), load clientId and clientSecret from device registration
+ if token.ClientIDHash != "" && token.ClientID == "" {
+ if err := loadDeviceRegistration(homeDir, token.ClientIDHash, &token); err != nil {
+ // Log warning but don't fail - token might still work for some operations
+ fmt.Printf("warning: failed to load device registration for clientIdHash %s: %v\n", token.ClientIDHash, err)
+ }
+ }
+
+ return &token, nil
+}
+
+// ListKiroTokenFiles lists all Kiro token files in the cache directory.
+// This supports multiple accounts by finding all token files.
+func ListKiroTokenFiles() ([]string, error) {
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ return nil, fmt.Errorf("failed to get home directory: %w", err)
+ }
+
+ cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache")
+
+ // Check if directory exists
+ if _, err := os.Stat(cacheDir); os.IsNotExist(err) {
+ return nil, nil // No token files
+ }
+
+ entries, err := os.ReadDir(cacheDir)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read cache directory: %w", err)
+ }
+
+ var tokenFiles []string
+ for _, entry := range entries {
+ if entry.IsDir() {
+ continue
+ }
+ name := entry.Name()
+ // Look for kiro token files only (avoid matching unrelated AWS SSO cache files)
+ if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") {
+ tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name))
+ }
+ }
+
+ return tokenFiles, nil
+}
+
+// LoadAllKiroTokens loads all Kiro tokens from the cache directory.
+// This supports multiple accounts.
+func LoadAllKiroTokens() ([]*KiroTokenData, error) {
+ files, err := ListKiroTokenFiles()
+ if err != nil {
+ return nil, err
+ }
+
+ var tokens []*KiroTokenData
+ for _, file := range files {
+ token, err := LoadKiroTokenFromPath(file)
+ if err != nil {
+ // Skip invalid token files
+ continue
+ }
+ tokens = append(tokens, token)
+ }
+
+ return tokens, nil
+}
+
+// JWTClaims represents the claims we care about from a JWT token.
+// JWT tokens from Kiro/AWS contain user information in the payload.
+type JWTClaims struct {
+ Email string `json:"email,omitempty"`
+ Sub string `json:"sub,omitempty"`
+ PreferredUser string `json:"preferred_username,omitempty"`
+ Name string `json:"name,omitempty"`
+ Iss string `json:"iss,omitempty"`
+}
+
+// ExtractEmailFromJWT extracts the user's email from a JWT access token.
+// JWT tokens typically have format: header.payload.signature
+// The payload is base64url-encoded JSON containing user claims.
+func ExtractEmailFromJWT(accessToken string) string {
+ if accessToken == "" {
+ return ""
+ }
+
+ // JWT format: header.payload.signature
+ parts := strings.Split(accessToken, ".")
+ if len(parts) != 3 {
+ return ""
+ }
+
+ // Decode the payload (second part)
+ payload := parts[1]
+
+ // Add padding if needed (base64url requires padding)
+ switch len(payload) % 4 {
+ case 2:
+ payload += "=="
+ case 3:
+ payload += "="
+ }
+
+ decoded, err := base64.URLEncoding.DecodeString(payload)
+ if err != nil {
+ // Try RawURLEncoding (no padding)
+ decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
+ if err != nil {
+ return ""
+ }
+ }
+
+ var claims JWTClaims
+ if err := json.Unmarshal(decoded, &claims); err != nil {
+ return ""
+ }
+
+ // Return email if available
+ if claims.Email != "" {
+ return claims.Email
+ }
+
+ // Fallback to preferred_username (some providers use this)
+ if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") {
+ return claims.PreferredUser
+ }
+
+ // Fallback to sub if it looks like an email
+ if claims.Sub != "" && strings.Contains(claims.Sub, "@") {
+ return claims.Sub
+ }
+
+ return ""
+}
+
+// SanitizeEmailForFilename sanitizes an email address for use in a filename.
+// Replaces special characters with underscores and prevents path traversal attacks.
+// Also handles URL-encoded characters to prevent encoded path traversal attempts.
+func SanitizeEmailForFilename(email string) string {
+ if email == "" {
+ return ""
+ }
+
+ result := email
+
+ // First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.)
+ // This prevents encoded characters from bypassing the sanitization.
+ // Note: We replace % last to catch any remaining encodings including double-encoding (%252F)
+ result = strings.ReplaceAll(result, "%2F", "_") // /
+ result = strings.ReplaceAll(result, "%2f", "_")
+ result = strings.ReplaceAll(result, "%5C", "_") // \
+ result = strings.ReplaceAll(result, "%5c", "_")
+ result = strings.ReplaceAll(result, "%2E", "_") // .
+ result = strings.ReplaceAll(result, "%2e", "_")
+ result = strings.ReplaceAll(result, "%00", "_") // null byte
+ result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks
+
+ // Replace characters that are problematic in filenames
+ // Keep @ and . in middle but replace other special characters
+ for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} {
+ result = strings.ReplaceAll(result, char, "_")
+ }
+
+ // Prevent path traversal: replace leading dots in each path component
+ // This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd"
+ parts := strings.Split(result, "_")
+ for i, part := range parts {
+ for strings.HasPrefix(part, ".") {
+ part = "_" + part[1:]
+ }
+ parts[i] = part
+ }
+ result = strings.Join(parts, "_")
+
+ return result
+}
+
+// ExtractIDCIdentifier extracts a unique identifier from IDC startUrl.
+// Examples:
+// - "https://d-1234567890.awsapps.com/start" -> "d-1234567890"
+// - "https://my-company.awsapps.com/start" -> "my-company"
+// - "https://acme-corp.awsapps.com/start" -> "acme-corp"
+func ExtractIDCIdentifier(startURL string) string {
+ if startURL == "" {
+ return ""
+ }
+
+ // Remove protocol prefix
+ url := strings.TrimPrefix(startURL, "https://")
+ url = strings.TrimPrefix(url, "http://")
+
+ // Extract subdomain (first part before the first dot)
+ // Format: {identifier}.awsapps.com/start
+ parts := strings.Split(url, ".")
+ if len(parts) > 0 && parts[0] != "" {
+ identifier := parts[0]
+ // Sanitize for filename safety
+ identifier = strings.ReplaceAll(identifier, "/", "_")
+ identifier = strings.ReplaceAll(identifier, "\\", "_")
+ identifier = strings.ReplaceAll(identifier, ":", "_")
+ return identifier
+ }
+
+ return ""
+}
+
+// GenerateTokenFileName generates a unique filename for token storage.
+// Priority: email > startUrl identifier (for IDC) > authMethod only
+// Email is unique, so no sequence suffix needed. Sequence is only added
+// when email is unavailable to prevent filename collisions.
+// Format: kiro-{authMethod}-{identifier}[-{seq}].json
+func GenerateTokenFileName(tokenData *KiroTokenData) string {
+ authMethod := tokenData.AuthMethod
+ if authMethod == "" {
+ authMethod = "unknown"
+ }
+
+ // Priority 1: Use email if available (no sequence needed, email is unique)
+ if tokenData.Email != "" {
+ // Sanitize email for filename (replace @ and . with -)
+ sanitizedEmail := tokenData.Email
+ sanitizedEmail = strings.ReplaceAll(sanitizedEmail, "@", "-")
+ sanitizedEmail = strings.ReplaceAll(sanitizedEmail, ".", "-")
+ return fmt.Sprintf("kiro-%s-%s.json", authMethod, sanitizedEmail)
+ }
+
+ // Generate sequence only when email is unavailable
+ seq := time.Now().UnixNano() % 100000
+
+ // Priority 2: For IDC, use startUrl identifier with sequence
+ if authMethod == "idc" && tokenData.StartURL != "" {
+ identifier := ExtractIDCIdentifier(tokenData.StartURL)
+ if identifier != "" {
+ return fmt.Sprintf("kiro-%s-%s-%05d.json", authMethod, identifier, seq)
+ }
+ }
+
+ // Priority 3: Fallback to authMethod only with sequence
+ return fmt.Sprintf("kiro-%s-%05d.json", authMethod, seq)
+}
diff --git a/internal/auth/kiro/aws_auth.go b/internal/auth/kiro/aws_auth.go
new file mode 100644
index 00000000..69ae2539
--- /dev/null
+++ b/internal/auth/kiro/aws_auth.go
@@ -0,0 +1,338 @@
+// Package kiro provides OAuth2 authentication functionality for AWS CodeWhisperer (Kiro) API.
+// This package implements token loading, refresh, and API communication with CodeWhisperer.
+package kiro
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ // awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.)
+ // Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com)
+ // used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct
+ // for their respective API operations.
+ awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com"
+ defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json"
+ targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits"
+ targetListModels = "AmazonCodeWhispererService.ListAvailableModels"
+ targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse"
+)
+
+// KiroAuth handles AWS CodeWhisperer authentication and API communication.
+// It provides methods for loading tokens, refreshing expired tokens,
+// and communicating with the CodeWhisperer API.
+type KiroAuth struct {
+ httpClient *http.Client
+ endpoint string
+}
+
+// NewKiroAuth creates a new Kiro authentication service.
+// It initializes the HTTP client with proxy settings from the configuration.
+//
+// Parameters:
+// - cfg: The application configuration containing proxy settings
+//
+// Returns:
+// - *KiroAuth: A new Kiro authentication service instance
+func NewKiroAuth(cfg *config.Config) *KiroAuth {
+ return &KiroAuth{
+ httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
+ endpoint: awsKiroEndpoint,
+ }
+}
+
+// LoadTokenFromFile loads token data from a file path.
+// This method reads and parses the token file, expanding ~ to the home directory.
+//
+// Parameters:
+// - tokenFile: Path to the token file (supports ~ expansion)
+//
+// Returns:
+// - *KiroTokenData: The parsed token data
+// - error: An error if file reading or parsing fails
+func (k *KiroAuth) LoadTokenFromFile(tokenFile string) (*KiroTokenData, error) {
+ // Expand ~ to home directory
+ if strings.HasPrefix(tokenFile, "~") {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil, fmt.Errorf("failed to get home directory: %w", err)
+ }
+ tokenFile = filepath.Join(home, tokenFile[1:])
+ }
+
+ data, err := os.ReadFile(tokenFile)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read token file: %w", err)
+ }
+
+ var tokenData KiroTokenData
+ if err := json.Unmarshal(data, &tokenData); err != nil {
+ return nil, fmt.Errorf("failed to parse token file: %w", err)
+ }
+
+ return &tokenData, nil
+}
+
+// IsTokenExpired checks if the token has expired.
+// This method parses the expiration timestamp and compares it with the current time.
+//
+// Parameters:
+// - tokenData: The token data to check
+//
+// Returns:
+// - bool: True if the token has expired, false otherwise
+func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool {
+ if tokenData.ExpiresAt == "" {
+ return true
+ }
+
+ expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
+ if err != nil {
+ // Try alternate format
+ expiresAt, err = time.Parse("2006-01-02T15:04:05.000Z", tokenData.ExpiresAt)
+ if err != nil {
+ return true
+ }
+ }
+
+ return time.Now().After(expiresAt)
+}
+
+// makeRequest sends a request to the CodeWhisperer API.
+// This is an internal method for making authenticated API calls.
+//
+// Parameters:
+// - ctx: The context for the request
+// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits")
+// - accessToken: The OAuth access token
+// - payload: The request payload
+//
+// Returns:
+// - []byte: The response body
+// - error: An error if the request fails
+func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) {
+ jsonBody, err := json.Marshal(payload)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody)))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/x-amz-json-1.0")
+ req.Header.Set("x-amz-target", target)
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := k.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request failed: %w", err)
+ }
+ defer func() {
+ if errClose := resp.Body.Close(); errClose != nil {
+ log.Errorf("failed to close response body: %v", errClose)
+ }
+ }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
+ }
+
+ return body, nil
+}
+
+// GetUsageLimits retrieves usage information from the CodeWhisperer API.
+// This method fetches the current usage statistics and subscription information.
+//
+// Parameters:
+// - ctx: The context for the request
+// - tokenData: The token data containing access token and profile ARN
+//
+// Returns:
+// - *KiroUsageInfo: The usage information
+// - error: An error if the request fails
+func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
+ payload := map[string]interface{}{
+ "origin": "AI_EDITOR",
+ "profileArn": tokenData.ProfileArn,
+ "resourceType": "AGENTIC_REQUEST",
+ }
+
+ body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload)
+ if err != nil {
+ return nil, err
+ }
+
+ var result struct {
+ SubscriptionInfo struct {
+ SubscriptionTitle string `json:"subscriptionTitle"`
+ } `json:"subscriptionInfo"`
+ UsageBreakdownList []struct {
+ CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
+ UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
+ } `json:"usageBreakdownList"`
+ NextDateReset float64 `json:"nextDateReset"`
+ }
+
+ if err := json.Unmarshal(body, &result); err != nil {
+ return nil, fmt.Errorf("failed to parse usage response: %w", err)
+ }
+
+ usage := &KiroUsageInfo{
+ SubscriptionTitle: result.SubscriptionInfo.SubscriptionTitle,
+ NextReset: fmt.Sprintf("%v", result.NextDateReset),
+ }
+
+ if len(result.UsageBreakdownList) > 0 {
+ usage.CurrentUsage = result.UsageBreakdownList[0].CurrentUsageWithPrecision
+ usage.UsageLimit = result.UsageBreakdownList[0].UsageLimitWithPrecision
+ }
+
+ return usage, nil
+}
+
+// ListAvailableModels retrieves available models from the CodeWhisperer API.
+// This method fetches the list of AI models available for the authenticated user.
+//
+// Parameters:
+// - ctx: The context for the request
+// - tokenData: The token data containing access token and profile ARN
+//
+// Returns:
+// - []*KiroModel: The list of available models
+// - error: An error if the request fails
+func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
+ payload := map[string]interface{}{
+ "origin": "AI_EDITOR",
+ "profileArn": tokenData.ProfileArn,
+ }
+
+ body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload)
+ if err != nil {
+ return nil, err
+ }
+
+ var result struct {
+ Models []struct {
+ ModelID string `json:"modelId"`
+ ModelName string `json:"modelName"`
+ Description string `json:"description"`
+ RateMultiplier float64 `json:"rateMultiplier"`
+ RateUnit string `json:"rateUnit"`
+ TokenLimits *struct {
+ MaxInputTokens int `json:"maxInputTokens"`
+ } `json:"tokenLimits"`
+ } `json:"models"`
+ }
+
+ if err := json.Unmarshal(body, &result); err != nil {
+ return nil, fmt.Errorf("failed to parse models response: %w", err)
+ }
+
+ models := make([]*KiroModel, 0, len(result.Models))
+ for _, m := range result.Models {
+ maxInputTokens := 0
+ if m.TokenLimits != nil {
+ maxInputTokens = m.TokenLimits.MaxInputTokens
+ }
+ models = append(models, &KiroModel{
+ ModelID: m.ModelID,
+ ModelName: m.ModelName,
+ Description: m.Description,
+ RateMultiplier: m.RateMultiplier,
+ RateUnit: m.RateUnit,
+ MaxInputTokens: maxInputTokens,
+ })
+ }
+
+ return models, nil
+}
+
+// CreateTokenStorage creates a new KiroTokenStorage from token data.
+// This method converts the token data into a storage structure suitable for persistence.
+//
+// Parameters:
+// - tokenData: The token data to convert
+//
+// Returns:
+// - *KiroTokenStorage: A new token storage instance
+func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorage {
+ return &KiroTokenStorage{
+ AccessToken: tokenData.AccessToken,
+ RefreshToken: tokenData.RefreshToken,
+ ProfileArn: tokenData.ProfileArn,
+ ExpiresAt: tokenData.ExpiresAt,
+ AuthMethod: tokenData.AuthMethod,
+ Provider: tokenData.Provider,
+ LastRefresh: time.Now().Format(time.RFC3339),
+ ClientID: tokenData.ClientID,
+ ClientSecret: tokenData.ClientSecret,
+ Region: tokenData.Region,
+ StartURL: tokenData.StartURL,
+ Email: tokenData.Email,
+ }
+}
+
+// ValidateToken checks if the token is valid by making a test API call.
+// This method verifies the token by attempting to fetch usage limits.
+//
+// Parameters:
+// - ctx: The context for the request
+// - tokenData: The token data to validate
+//
+// Returns:
+// - error: An error if the token is invalid
+func (k *KiroAuth) ValidateToken(ctx context.Context, tokenData *KiroTokenData) error {
+ _, err := k.GetUsageLimits(ctx, tokenData)
+ return err
+}
+
+// UpdateTokenStorage updates an existing token storage with new token data.
+// This method refreshes the token storage with newly obtained access and refresh tokens.
+//
+// Parameters:
+// - storage: The existing token storage to update
+// - tokenData: The new token data to apply
+func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *KiroTokenData) {
+ storage.AccessToken = tokenData.AccessToken
+ storage.RefreshToken = tokenData.RefreshToken
+ storage.ProfileArn = tokenData.ProfileArn
+ storage.ExpiresAt = tokenData.ExpiresAt
+ storage.AuthMethod = tokenData.AuthMethod
+ storage.Provider = tokenData.Provider
+ storage.LastRefresh = time.Now().Format(time.RFC3339)
+ if tokenData.ClientID != "" {
+ storage.ClientID = tokenData.ClientID
+ }
+ if tokenData.ClientSecret != "" {
+ storage.ClientSecret = tokenData.ClientSecret
+ }
+ if tokenData.Region != "" {
+ storage.Region = tokenData.Region
+ }
+ if tokenData.StartURL != "" {
+ storage.StartURL = tokenData.StartURL
+ }
+ if tokenData.Email != "" {
+ storage.Email = tokenData.Email
+ }
+}
diff --git a/internal/auth/kiro/aws_test.go b/internal/auth/kiro/aws_test.go
new file mode 100644
index 00000000..194ad59e
--- /dev/null
+++ b/internal/auth/kiro/aws_test.go
@@ -0,0 +1,311 @@
+package kiro
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "testing"
+)
+
+func TestExtractEmailFromJWT(t *testing.T) {
+ tests := []struct {
+ name string
+ token string
+ expected string
+ }{
+ {
+ name: "Empty token",
+ token: "",
+ expected: "",
+ },
+ {
+ name: "Invalid token format",
+ token: "not.a.valid.jwt",
+ expected: "",
+ },
+ {
+ name: "Invalid token - not base64",
+ token: "xxx.yyy.zzz",
+ expected: "",
+ },
+ {
+ name: "Valid JWT with email",
+ token: createTestJWT(map[string]any{"email": "test@example.com", "sub": "user123"}),
+ expected: "test@example.com",
+ },
+ {
+ name: "JWT without email but with preferred_username",
+ token: createTestJWT(map[string]any{"preferred_username": "user@domain.com", "sub": "user123"}),
+ expected: "user@domain.com",
+ },
+ {
+ name: "JWT with email-like sub",
+ token: createTestJWT(map[string]any{"sub": "another@test.com"}),
+ expected: "another@test.com",
+ },
+ {
+ name: "JWT without any email fields",
+ token: createTestJWT(map[string]any{"sub": "user123", "name": "Test User"}),
+ expected: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := ExtractEmailFromJWT(tt.token)
+ if result != tt.expected {
+ t.Errorf("ExtractEmailFromJWT() = %q, want %q", result, tt.expected)
+ }
+ })
+ }
+}
+
+func TestSanitizeEmailForFilename(t *testing.T) {
+ tests := []struct {
+ name string
+ email string
+ expected string
+ }{
+ {
+ name: "Empty email",
+ email: "",
+ expected: "",
+ },
+ {
+ name: "Simple email",
+ email: "user@example.com",
+ expected: "user@example.com",
+ },
+ {
+ name: "Email with space",
+ email: "user name@example.com",
+ expected: "user_name@example.com",
+ },
+ {
+ name: "Email with special chars",
+ email: "user:name@example.com",
+ expected: "user_name@example.com",
+ },
+ {
+ name: "Email with multiple special chars",
+ email: "user/name:test@example.com",
+ expected: "user_name_test@example.com",
+ },
+ {
+ name: "Path traversal attempt",
+ email: "../../../etc/passwd",
+ expected: "_.__.__._etc_passwd",
+ },
+ {
+ name: "Path traversal with backslash",
+ email: `..\..\..\..\windows\system32`,
+ expected: "_.__.__.__._windows_system32",
+ },
+ {
+ name: "Null byte injection attempt",
+ email: "user\x00@evil.com",
+ expected: "user_@evil.com",
+ },
+ // URL-encoded path traversal tests
+ {
+ name: "URL-encoded slash",
+ email: "user%2Fpath@example.com",
+ expected: "user_path@example.com",
+ },
+ {
+ name: "URL-encoded backslash",
+ email: "user%5Cpath@example.com",
+ expected: "user_path@example.com",
+ },
+ {
+ name: "URL-encoded dot",
+ email: "%2E%2E%2Fetc%2Fpasswd",
+ expected: "___etc_passwd",
+ },
+ {
+ name: "URL-encoded null",
+ email: "user%00@evil.com",
+ expected: "user_@evil.com",
+ },
+ {
+ name: "Double URL-encoding attack",
+ email: "%252F%252E%252E",
+ expected: "_252F_252E_252E", // % replaced with _, remaining chars preserved (safe)
+ },
+ {
+ name: "Mixed case URL-encoding",
+ email: "%2f%2F%5c%5C",
+ expected: "____",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := SanitizeEmailForFilename(tt.email)
+ if result != tt.expected {
+ t.Errorf("SanitizeEmailForFilename() = %q, want %q", result, tt.expected)
+ }
+ })
+ }
+}
+
+// createTestJWT creates a test JWT token with the given claims
+func createTestJWT(claims map[string]any) string {
+ header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
+
+ payloadBytes, _ := json.Marshal(claims)
+ payload := base64.RawURLEncoding.EncodeToString(payloadBytes)
+
+ signature := base64.RawURLEncoding.EncodeToString([]byte("fake-signature"))
+
+ return header + "." + payload + "." + signature
+}
+
+func TestExtractIDCIdentifier(t *testing.T) {
+ tests := []struct {
+ name string
+ startURL string
+ expected string
+ }{
+ {
+ name: "Empty URL",
+ startURL: "",
+ expected: "",
+ },
+ {
+ name: "Standard IDC URL with d- prefix",
+ startURL: "https://d-1234567890.awsapps.com/start",
+ expected: "d-1234567890",
+ },
+ {
+ name: "IDC URL with company name",
+ startURL: "https://my-company.awsapps.com/start",
+ expected: "my-company",
+ },
+ {
+ name: "IDC URL with simple name",
+ startURL: "https://acme-corp.awsapps.com/start",
+ expected: "acme-corp",
+ },
+ {
+ name: "IDC URL without https",
+ startURL: "http://d-9876543210.awsapps.com/start",
+ expected: "d-9876543210",
+ },
+ {
+ name: "IDC URL with subdomain only",
+ startURL: "https://test.awsapps.com/start",
+ expected: "test",
+ },
+ {
+ name: "Builder ID URL",
+ startURL: "https://view.awsapps.com/start",
+ expected: "view",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := ExtractIDCIdentifier(tt.startURL)
+ if result != tt.expected {
+ t.Errorf("ExtractIDCIdentifier() = %q, want %q", result, tt.expected)
+ }
+ })
+ }
+}
+
+func TestGenerateTokenFileName(t *testing.T) {
+ tests := []struct {
+ name string
+ tokenData *KiroTokenData
+ expected string
+ }{
+ {
+ name: "IDC with email",
+ tokenData: &KiroTokenData{
+ AuthMethod: "idc",
+ Email: "user@example.com",
+ StartURL: "https://d-1234567890.awsapps.com/start",
+ },
+ expected: "kiro-idc-user-example-com.json",
+ },
+ {
+ name: "IDC without email but with startUrl",
+ tokenData: &KiroTokenData{
+ AuthMethod: "idc",
+ Email: "",
+ StartURL: "https://d-1234567890.awsapps.com/start",
+ },
+ expected: "kiro-idc-d-1234567890.json",
+ },
+ {
+ name: "IDC with company name in startUrl",
+ tokenData: &KiroTokenData{
+ AuthMethod: "idc",
+ Email: "",
+ StartURL: "https://my-company.awsapps.com/start",
+ },
+ expected: "kiro-idc-my-company.json",
+ },
+ {
+ name: "IDC without email and without startUrl",
+ tokenData: &KiroTokenData{
+ AuthMethod: "idc",
+ Email: "",
+ StartURL: "",
+ },
+ expected: "kiro-idc.json",
+ },
+ {
+ name: "Builder ID with email",
+ tokenData: &KiroTokenData{
+ AuthMethod: "builder-id",
+ Email: "user@gmail.com",
+ StartURL: "https://view.awsapps.com/start",
+ },
+ expected: "kiro-builder-id-user-gmail-com.json",
+ },
+ {
+ name: "Builder ID without email",
+ tokenData: &KiroTokenData{
+ AuthMethod: "builder-id",
+ Email: "",
+ StartURL: "https://view.awsapps.com/start",
+ },
+ expected: "kiro-builder-id.json",
+ },
+ {
+ name: "Social auth with email",
+ tokenData: &KiroTokenData{
+ AuthMethod: "google",
+ Email: "user@gmail.com",
+ },
+ expected: "kiro-google-user-gmail-com.json",
+ },
+ {
+ name: "Empty auth method",
+ tokenData: &KiroTokenData{
+ AuthMethod: "",
+ Email: "",
+ },
+ expected: "kiro-unknown.json",
+ },
+ {
+ name: "Email with special characters",
+ tokenData: &KiroTokenData{
+ AuthMethod: "idc",
+ Email: "user.name+tag@sub.example.com",
+ StartURL: "https://d-1234567890.awsapps.com/start",
+ },
+ expected: "kiro-idc-user-name+tag-sub-example-com.json",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := GenerateTokenFileName(tt.tokenData)
+ if result != tt.expected {
+ t.Errorf("GenerateTokenFileName() = %q, want %q", result, tt.expected)
+ }
+ })
+ }
+}
diff --git a/internal/auth/kiro/background_refresh.go b/internal/auth/kiro/background_refresh.go
new file mode 100644
index 00000000..d64c7475
--- /dev/null
+++ b/internal/auth/kiro/background_refresh.go
@@ -0,0 +1,247 @@
+package kiro
+
+import (
+ "context"
+ "log"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "golang.org/x/sync/semaphore"
+)
+
+type Token struct {
+ ID string
+ AccessToken string
+ RefreshToken string
+ ExpiresAt time.Time
+ LastVerified time.Time
+ ClientID string
+ ClientSecret string
+ AuthMethod string
+ Provider string
+ StartURL string
+ Region string
+}
+
+type TokenRepository interface {
+ FindOldestUnverified(limit int) []*Token
+ UpdateToken(token *Token) error
+}
+
+type RefresherOption func(*BackgroundRefresher)
+
+func WithInterval(interval time.Duration) RefresherOption {
+ return func(r *BackgroundRefresher) {
+ r.interval = interval
+ }
+}
+
+func WithBatchSize(size int) RefresherOption {
+ return func(r *BackgroundRefresher) {
+ r.batchSize = size
+ }
+}
+
+func WithConcurrency(concurrency int) RefresherOption {
+ return func(r *BackgroundRefresher) {
+ r.concurrency = concurrency
+ }
+}
+
+type BackgroundRefresher struct {
+ interval time.Duration
+ batchSize int
+ concurrency int
+ tokenRepo TokenRepository
+ stopCh chan struct{}
+ wg sync.WaitGroup
+ oauth *KiroOAuth
+ ssoClient *SSOOIDCClient
+ callbackMu sync.RWMutex // 保护回调函数的并发访问
+ onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
+}
+
+func NewBackgroundRefresher(repo TokenRepository, opts ...RefresherOption) *BackgroundRefresher {
+ r := &BackgroundRefresher{
+ interval: time.Minute,
+ batchSize: 50,
+ concurrency: 10,
+ tokenRepo: repo,
+ stopCh: make(chan struct{}),
+ oauth: nil, // Lazy init - will be set when config available
+ ssoClient: nil, // Lazy init - will be set when config available
+ }
+ for _, opt := range opts {
+ opt(r)
+ }
+ return r
+}
+
+// WithConfig sets the configuration for OAuth and SSO clients.
+func WithConfig(cfg *config.Config) RefresherOption {
+ return func(r *BackgroundRefresher) {
+ r.oauth = NewKiroOAuth(cfg)
+ r.ssoClient = NewSSOOIDCClient(cfg)
+ }
+}
+
+// WithOnTokenRefreshed sets the callback function to be called when a token is successfully refreshed.
+// The callback receives the token ID (filename) and the new token data.
+// This allows external components (e.g., Watcher) to be notified of token updates.
+func WithOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) RefresherOption {
+ return func(r *BackgroundRefresher) {
+ r.callbackMu.Lock()
+ r.onTokenRefreshed = callback
+ r.callbackMu.Unlock()
+ }
+}
+
+func (r *BackgroundRefresher) Start(ctx context.Context) {
+ r.wg.Add(1)
+ go func() {
+ defer r.wg.Done()
+ ticker := time.NewTicker(r.interval)
+ defer ticker.Stop()
+
+ r.refreshBatch(ctx)
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-r.stopCh:
+ return
+ case <-ticker.C:
+ r.refreshBatch(ctx)
+ }
+ }
+ }()
+}
+
+func (r *BackgroundRefresher) Stop() {
+ close(r.stopCh)
+ r.wg.Wait()
+}
+
+func (r *BackgroundRefresher) refreshBatch(ctx context.Context) {
+ tokens := r.tokenRepo.FindOldestUnverified(r.batchSize)
+ if len(tokens) == 0 {
+ return
+ }
+
+ sem := semaphore.NewWeighted(int64(r.concurrency))
+ var wg sync.WaitGroup
+
+ for i, token := range tokens {
+ if i > 0 {
+ select {
+ case <-ctx.Done():
+ return
+ case <-r.stopCh:
+ return
+ case <-time.After(100 * time.Millisecond):
+ }
+ }
+
+ if err := sem.Acquire(ctx, 1); err != nil {
+ return
+ }
+
+ wg.Add(1)
+ go func(t *Token) {
+ defer wg.Done()
+ defer sem.Release(1)
+ r.refreshSingle(ctx, t)
+ }(token)
+ }
+
+ wg.Wait()
+}
+
+func (r *BackgroundRefresher) refreshSingle(ctx context.Context, token *Token) {
+ // Normalize auth method to lowercase for case-insensitive matching
+ authMethod := strings.ToLower(token.AuthMethod)
+
+ // Create refresh function based on auth method
+ refreshFunc := func(ctx context.Context) (*KiroTokenData, error) {
+ switch authMethod {
+ case "idc":
+ return r.ssoClient.RefreshTokenWithRegion(
+ ctx,
+ token.ClientID,
+ token.ClientSecret,
+ token.RefreshToken,
+ token.Region,
+ token.StartURL,
+ )
+ case "builder-id":
+ return r.ssoClient.RefreshToken(
+ ctx,
+ token.ClientID,
+ token.ClientSecret,
+ token.RefreshToken,
+ )
+ default:
+ return r.oauth.RefreshTokenWithFingerprint(ctx, token.RefreshToken, token.ID)
+ }
+ }
+
+ // Use graceful degradation for better reliability
+ result := RefreshWithGracefulDegradation(
+ ctx,
+ refreshFunc,
+ token.AccessToken,
+ token.ExpiresAt,
+ )
+
+ if result.Error != nil {
+ log.Printf("failed to refresh token %s: %v", token.ID, result.Error)
+ return
+ }
+
+ newTokenData := result.TokenData
+ if result.UsedFallback {
+ log.Printf("token %s: using existing token as fallback (refresh failed but token still valid)", token.ID)
+ // Don't update the token file if we're using fallback
+ // Just update LastVerified to prevent immediate re-check
+ token.LastVerified = time.Now()
+ return
+ }
+
+ token.AccessToken = newTokenData.AccessToken
+ if newTokenData.RefreshToken != "" {
+ token.RefreshToken = newTokenData.RefreshToken
+ }
+ token.LastVerified = time.Now()
+
+ if newTokenData.ExpiresAt != "" {
+ if expTime, parseErr := time.Parse(time.RFC3339, newTokenData.ExpiresAt); parseErr == nil {
+ token.ExpiresAt = expTime
+ }
+ }
+
+ if err := r.tokenRepo.UpdateToken(token); err != nil {
+ log.Printf("failed to update token %s: %v", token.ID, err)
+ return
+ }
+
+ // 方案 A: 刷新成功后触发回调,通知 Watcher 更新内存中的 Auth 对象
+ r.callbackMu.RLock()
+ callback := r.onTokenRefreshed
+ r.callbackMu.RUnlock()
+
+ if callback != nil {
+ // 使用 defer recover 隔离回调 panic,防止崩溃整个进程
+ func() {
+ defer func() {
+ if rec := recover(); rec != nil {
+ log.Printf("background refresh: callback panic for token %s: %v", token.ID, rec)
+ }
+ }()
+ log.Printf("background refresh: notifying token refresh callback for %s", token.ID)
+ callback(token.ID, newTokenData)
+ }()
+ }
+}
diff --git a/internal/auth/kiro/codewhisperer_client.go b/internal/auth/kiro/codewhisperer_client.go
new file mode 100644
index 00000000..0a7392e8
--- /dev/null
+++ b/internal/auth/kiro/codewhisperer_client.go
@@ -0,0 +1,166 @@
+// Package kiro provides CodeWhisperer API client for fetching user info.
+package kiro
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ codeWhispererAPI = "https://codewhisperer.us-east-1.amazonaws.com"
+ kiroVersion = "0.6.18"
+)
+
+// CodeWhispererClient handles CodeWhisperer API calls.
+type CodeWhispererClient struct {
+ httpClient *http.Client
+ machineID string
+}
+
+// UsageLimitsResponse represents the getUsageLimits API response.
+type UsageLimitsResponse struct {
+ DaysUntilReset *int `json:"daysUntilReset,omitempty"`
+ NextDateReset *float64 `json:"nextDateReset,omitempty"`
+ UserInfo *UserInfo `json:"userInfo,omitempty"`
+ SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
+ UsageBreakdownList []UsageBreakdown `json:"usageBreakdownList,omitempty"`
+}
+
+// UserInfo contains user information from the API.
+type UserInfo struct {
+ Email string `json:"email,omitempty"`
+ UserID string `json:"userId,omitempty"`
+}
+
+// SubscriptionInfo contains subscription details.
+type SubscriptionInfo struct {
+ SubscriptionTitle string `json:"subscriptionTitle,omitempty"`
+ Type string `json:"type,omitempty"`
+}
+
+// UsageBreakdown contains usage details.
+type UsageBreakdown struct {
+ UsageLimit *int `json:"usageLimit,omitempty"`
+ CurrentUsage *int `json:"currentUsage,omitempty"`
+ UsageLimitWithPrecision *float64 `json:"usageLimitWithPrecision,omitempty"`
+ CurrentUsageWithPrecision *float64 `json:"currentUsageWithPrecision,omitempty"`
+ NextDateReset *float64 `json:"nextDateReset,omitempty"`
+ DisplayName string `json:"displayName,omitempty"`
+ ResourceType string `json:"resourceType,omitempty"`
+}
+
+// NewCodeWhispererClient creates a new CodeWhisperer client.
+func NewCodeWhispererClient(cfg *config.Config, machineID string) *CodeWhispererClient {
+ client := &http.Client{Timeout: 30 * time.Second}
+ if cfg != nil {
+ client = util.SetProxy(&cfg.SDKConfig, client)
+ }
+ if machineID == "" {
+ machineID = uuid.New().String()
+ }
+ return &CodeWhispererClient{
+ httpClient: client,
+ machineID: machineID,
+ }
+}
+
+// generateInvocationID generates a unique invocation ID.
+func generateInvocationID() string {
+ return uuid.New().String()
+}
+
+// GetUsageLimits fetches usage limits and user info from CodeWhisperer API.
+// This is the recommended way to get user email after login.
+func (c *CodeWhispererClient) GetUsageLimits(ctx context.Context, accessToken string) (*UsageLimitsResponse, error) {
+ url := fmt.Sprintf("%s/getUsageLimits?isEmailRequired=true&origin=AI_EDITOR&resourceType=AGENTIC_REQUEST", codeWhispererAPI)
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ // Set headers to match Kiro IDE
+ xAmzUserAgent := fmt.Sprintf("aws-sdk-js/1.0.0 KiroIDE-%s-%s", kiroVersion, c.machineID)
+ userAgent := fmt.Sprintf("aws-sdk-js/1.0.0 ua/2.1 os/windows lang/js md/nodejs#20.16.0 api/codewhispererruntime#1.0.0 m/E KiroIDE-%s-%s", kiroVersion, c.machineID)
+
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ req.Header.Set("x-amz-user-agent", xAmzUserAgent)
+ req.Header.Set("User-Agent", userAgent)
+ req.Header.Set("amz-sdk-invocation-id", generateInvocationID())
+ req.Header.Set("amz-sdk-request", "attempt=1; max=1")
+ req.Header.Set("Connection", "close")
+
+ log.Debugf("codewhisperer: GET %s", url)
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response: %w", err)
+ }
+
+ log.Debugf("codewhisperer: status=%d, body=%s", resp.StatusCode, string(body))
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
+ }
+
+ var result UsageLimitsResponse
+ if err := json.Unmarshal(body, &result); err != nil {
+ return nil, fmt.Errorf("failed to parse response: %w", err)
+ }
+
+ return &result, nil
+}
+
+// FetchUserEmailFromAPI fetches user email using CodeWhisperer getUsageLimits API.
+// This is more reliable than JWT parsing as it uses the official API.
+func (c *CodeWhispererClient) FetchUserEmailFromAPI(ctx context.Context, accessToken string) string {
+ resp, err := c.GetUsageLimits(ctx, accessToken)
+ if err != nil {
+ log.Debugf("codewhisperer: failed to get usage limits: %v", err)
+ return ""
+ }
+
+ if resp.UserInfo != nil && resp.UserInfo.Email != "" {
+ log.Debugf("codewhisperer: got email from API: %s", resp.UserInfo.Email)
+ return resp.UserInfo.Email
+ }
+
+ log.Debugf("codewhisperer: no email in response")
+ return ""
+}
+
+// FetchUserEmailWithFallback fetches user email with multiple fallback methods.
+// Priority: 1. CodeWhisperer API 2. userinfo endpoint 3. JWT parsing
+func FetchUserEmailWithFallback(ctx context.Context, cfg *config.Config, accessToken string) string {
+ // Method 1: Try CodeWhisperer API (most reliable)
+ cwClient := NewCodeWhispererClient(cfg, "")
+ email := cwClient.FetchUserEmailFromAPI(ctx, accessToken)
+ if email != "" {
+ return email
+ }
+
+ // Method 2: Try SSO OIDC userinfo endpoint
+ ssoClient := NewSSOOIDCClient(cfg)
+ email = ssoClient.FetchUserEmail(ctx, accessToken)
+ if email != "" {
+ return email
+ }
+
+ // Method 3: Fallback to JWT parsing
+ return ExtractEmailFromJWT(accessToken)
+}
diff --git a/internal/auth/kiro/cooldown.go b/internal/auth/kiro/cooldown.go
new file mode 100644
index 00000000..c1aabbcb
--- /dev/null
+++ b/internal/auth/kiro/cooldown.go
@@ -0,0 +1,112 @@
+package kiro
+
+import (
+ "sync"
+ "time"
+)
+
+const (
+ CooldownReason429 = "rate_limit_exceeded"
+ CooldownReasonSuspended = "account_suspended"
+ CooldownReasonQuotaExhausted = "quota_exhausted"
+
+ DefaultShortCooldown = 1 * time.Minute
+ MaxShortCooldown = 5 * time.Minute
+ LongCooldown = 24 * time.Hour
+)
+
+type CooldownManager struct {
+ mu sync.RWMutex
+ cooldowns map[string]time.Time
+ reasons map[string]string
+}
+
+func NewCooldownManager() *CooldownManager {
+ return &CooldownManager{
+ cooldowns: make(map[string]time.Time),
+ reasons: make(map[string]string),
+ }
+}
+
+func (cm *CooldownManager) SetCooldown(tokenKey string, duration time.Duration, reason string) {
+ cm.mu.Lock()
+ defer cm.mu.Unlock()
+ cm.cooldowns[tokenKey] = time.Now().Add(duration)
+ cm.reasons[tokenKey] = reason
+}
+
+func (cm *CooldownManager) IsInCooldown(tokenKey string) bool {
+ cm.mu.RLock()
+ defer cm.mu.RUnlock()
+ endTime, exists := cm.cooldowns[tokenKey]
+ if !exists {
+ return false
+ }
+ return time.Now().Before(endTime)
+}
+
+func (cm *CooldownManager) GetRemainingCooldown(tokenKey string) time.Duration {
+ cm.mu.RLock()
+ defer cm.mu.RUnlock()
+ endTime, exists := cm.cooldowns[tokenKey]
+ if !exists {
+ return 0
+ }
+ remaining := time.Until(endTime)
+ if remaining < 0 {
+ return 0
+ }
+ return remaining
+}
+
+func (cm *CooldownManager) GetCooldownReason(tokenKey string) string {
+ cm.mu.RLock()
+ defer cm.mu.RUnlock()
+ return cm.reasons[tokenKey]
+}
+
+func (cm *CooldownManager) ClearCooldown(tokenKey string) {
+ cm.mu.Lock()
+ defer cm.mu.Unlock()
+ delete(cm.cooldowns, tokenKey)
+ delete(cm.reasons, tokenKey)
+}
+
+func (cm *CooldownManager) CleanupExpired() {
+ cm.mu.Lock()
+ defer cm.mu.Unlock()
+ now := time.Now()
+ for tokenKey, endTime := range cm.cooldowns {
+ if now.After(endTime) {
+ delete(cm.cooldowns, tokenKey)
+ delete(cm.reasons, tokenKey)
+ }
+ }
+}
+
+func (cm *CooldownManager) StartCleanupRoutine(interval time.Duration, stopCh <-chan struct{}) {
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ cm.CleanupExpired()
+ case <-stopCh:
+ return
+ }
+ }
+}
+
+func CalculateCooldownFor429(retryCount int) time.Duration {
+ duration := DefaultShortCooldown * time.Duration(1< MaxShortCooldown {
+ return MaxShortCooldown
+ }
+ return duration
+}
+
+func CalculateCooldownUntilNextDay() time.Duration {
+ now := time.Now()
+ nextDay := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
+ return time.Until(nextDay)
+}
diff --git a/internal/auth/kiro/cooldown_test.go b/internal/auth/kiro/cooldown_test.go
new file mode 100644
index 00000000..e0b35df4
--- /dev/null
+++ b/internal/auth/kiro/cooldown_test.go
@@ -0,0 +1,240 @@
+package kiro
+
+import (
+ "sync"
+ "testing"
+ "time"
+)
+
+func TestNewCooldownManager(t *testing.T) {
+ cm := NewCooldownManager()
+ if cm == nil {
+ t.Fatal("expected non-nil CooldownManager")
+ }
+ if cm.cooldowns == nil {
+ t.Error("expected non-nil cooldowns map")
+ }
+ if cm.reasons == nil {
+ t.Error("expected non-nil reasons map")
+ }
+}
+
+func TestSetCooldown(t *testing.T) {
+ cm := NewCooldownManager()
+ cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
+
+ if !cm.IsInCooldown("token1") {
+ t.Error("expected token to be in cooldown")
+ }
+ if cm.GetCooldownReason("token1") != CooldownReason429 {
+ t.Errorf("expected reason %s, got %s", CooldownReason429, cm.GetCooldownReason("token1"))
+ }
+}
+
+func TestIsInCooldown_NotSet(t *testing.T) {
+ cm := NewCooldownManager()
+ if cm.IsInCooldown("nonexistent") {
+ t.Error("expected non-existent token to not be in cooldown")
+ }
+}
+
+func TestIsInCooldown_Expired(t *testing.T) {
+ cm := NewCooldownManager()
+ cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
+
+ time.Sleep(10 * time.Millisecond)
+
+ if cm.IsInCooldown("token1") {
+ t.Error("expected expired cooldown to return false")
+ }
+}
+
+func TestGetRemainingCooldown(t *testing.T) {
+ cm := NewCooldownManager()
+ cm.SetCooldown("token1", 1*time.Second, CooldownReason429)
+
+ remaining := cm.GetRemainingCooldown("token1")
+ if remaining <= 0 || remaining > 1*time.Second {
+ t.Errorf("expected remaining cooldown between 0 and 1s, got %v", remaining)
+ }
+}
+
+func TestGetRemainingCooldown_NotSet(t *testing.T) {
+ cm := NewCooldownManager()
+ remaining := cm.GetRemainingCooldown("nonexistent")
+ if remaining != 0 {
+ t.Errorf("expected 0 remaining for non-existent, got %v", remaining)
+ }
+}
+
+func TestGetRemainingCooldown_Expired(t *testing.T) {
+ cm := NewCooldownManager()
+ cm.SetCooldown("token1", 1*time.Millisecond, CooldownReason429)
+
+ time.Sleep(10 * time.Millisecond)
+
+ remaining := cm.GetRemainingCooldown("token1")
+ if remaining != 0 {
+ t.Errorf("expected 0 remaining for expired, got %v", remaining)
+ }
+}
+
+func TestGetCooldownReason(t *testing.T) {
+ cm := NewCooldownManager()
+ cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
+
+ reason := cm.GetCooldownReason("token1")
+ if reason != CooldownReasonSuspended {
+ t.Errorf("expected reason %s, got %s", CooldownReasonSuspended, reason)
+ }
+}
+
+func TestGetCooldownReason_NotSet(t *testing.T) {
+ cm := NewCooldownManager()
+ reason := cm.GetCooldownReason("nonexistent")
+ if reason != "" {
+ t.Errorf("expected empty reason for non-existent, got %s", reason)
+ }
+}
+
+func TestClearCooldown(t *testing.T) {
+ cm := NewCooldownManager()
+ cm.SetCooldown("token1", 1*time.Minute, CooldownReason429)
+ cm.ClearCooldown("token1")
+
+ if cm.IsInCooldown("token1") {
+ t.Error("expected cooldown to be cleared")
+ }
+ if cm.GetCooldownReason("token1") != "" {
+ t.Error("expected reason to be cleared")
+ }
+}
+
+func TestClearCooldown_NonExistent(t *testing.T) {
+ cm := NewCooldownManager()
+ cm.ClearCooldown("nonexistent")
+}
+
+func TestCleanupExpired(t *testing.T) {
+ cm := NewCooldownManager()
+ cm.SetCooldown("expired1", 1*time.Millisecond, CooldownReason429)
+ cm.SetCooldown("expired2", 1*time.Millisecond, CooldownReason429)
+ cm.SetCooldown("active", 1*time.Hour, CooldownReason429)
+
+ time.Sleep(10 * time.Millisecond)
+ cm.CleanupExpired()
+
+ if cm.GetCooldownReason("expired1") != "" {
+ t.Error("expected expired1 to be cleaned up")
+ }
+ if cm.GetCooldownReason("expired2") != "" {
+ t.Error("expected expired2 to be cleaned up")
+ }
+ if cm.GetCooldownReason("active") != CooldownReason429 {
+ t.Error("expected active to remain")
+ }
+}
+
+func TestCalculateCooldownFor429_FirstRetry(t *testing.T) {
+ duration := CalculateCooldownFor429(0)
+ if duration != DefaultShortCooldown {
+ t.Errorf("expected %v for retry 0, got %v", DefaultShortCooldown, duration)
+ }
+}
+
+func TestCalculateCooldownFor429_Exponential(t *testing.T) {
+ d1 := CalculateCooldownFor429(1)
+ d2 := CalculateCooldownFor429(2)
+
+ if d2 <= d1 {
+ t.Errorf("expected d2 > d1, got d1=%v, d2=%v", d1, d2)
+ }
+}
+
+func TestCalculateCooldownFor429_MaxCap(t *testing.T) {
+ duration := CalculateCooldownFor429(10)
+ if duration > MaxShortCooldown {
+ t.Errorf("expected max %v, got %v", MaxShortCooldown, duration)
+ }
+}
+
+func TestCalculateCooldownUntilNextDay(t *testing.T) {
+ duration := CalculateCooldownUntilNextDay()
+ if duration <= 0 || duration > 24*time.Hour {
+ t.Errorf("expected duration between 0 and 24h, got %v", duration)
+ }
+}
+
+func TestCooldownManager_ConcurrentAccess(t *testing.T) {
+ cm := NewCooldownManager()
+ const numGoroutines = 50
+ const numOperations = 100
+
+ var wg sync.WaitGroup
+ wg.Add(numGoroutines)
+
+ for i := 0; i < numGoroutines; i++ {
+ go func(id int) {
+ defer wg.Done()
+ tokenKey := "token" + string(rune('a'+id%10))
+ for j := 0; j < numOperations; j++ {
+ switch j % 6 {
+ case 0:
+ cm.SetCooldown(tokenKey, time.Duration(j)*time.Millisecond, CooldownReason429)
+ case 1:
+ cm.IsInCooldown(tokenKey)
+ case 2:
+ cm.GetRemainingCooldown(tokenKey)
+ case 3:
+ cm.GetCooldownReason(tokenKey)
+ case 4:
+ cm.ClearCooldown(tokenKey)
+ case 5:
+ cm.CleanupExpired()
+ }
+ }
+ }(i)
+ }
+
+ wg.Wait()
+}
+
+func TestCooldownReasonConstants(t *testing.T) {
+ if CooldownReason429 != "rate_limit_exceeded" {
+ t.Errorf("unexpected CooldownReason429: %s", CooldownReason429)
+ }
+ if CooldownReasonSuspended != "account_suspended" {
+ t.Errorf("unexpected CooldownReasonSuspended: %s", CooldownReasonSuspended)
+ }
+ if CooldownReasonQuotaExhausted != "quota_exhausted" {
+ t.Errorf("unexpected CooldownReasonQuotaExhausted: %s", CooldownReasonQuotaExhausted)
+ }
+}
+
+func TestDefaultConstants(t *testing.T) {
+ if DefaultShortCooldown != 1*time.Minute {
+ t.Errorf("unexpected DefaultShortCooldown: %v", DefaultShortCooldown)
+ }
+ if MaxShortCooldown != 5*time.Minute {
+ t.Errorf("unexpected MaxShortCooldown: %v", MaxShortCooldown)
+ }
+ if LongCooldown != 24*time.Hour {
+ t.Errorf("unexpected LongCooldown: %v", LongCooldown)
+ }
+}
+
+func TestSetCooldown_OverwritesPrevious(t *testing.T) {
+ cm := NewCooldownManager()
+ cm.SetCooldown("token1", 1*time.Hour, CooldownReason429)
+ cm.SetCooldown("token1", 1*time.Minute, CooldownReasonSuspended)
+
+ reason := cm.GetCooldownReason("token1")
+ if reason != CooldownReasonSuspended {
+ t.Errorf("expected reason to be overwritten to %s, got %s", CooldownReasonSuspended, reason)
+ }
+
+ remaining := cm.GetRemainingCooldown("token1")
+ if remaining > 1*time.Minute {
+ t.Errorf("expected remaining <= 1 minute, got %v", remaining)
+ }
+}
diff --git a/internal/auth/kiro/fingerprint.go b/internal/auth/kiro/fingerprint.go
new file mode 100644
index 00000000..c35e62b2
--- /dev/null
+++ b/internal/auth/kiro/fingerprint.go
@@ -0,0 +1,197 @@
+package kiro
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+ "fmt"
+ "math/rand"
+ "net/http"
+ "sync"
+ "time"
+)
+
+// Fingerprint 多维度指纹信息
+type Fingerprint struct {
+ SDKVersion string // 1.0.20-1.0.27
+ OSType string // darwin/windows/linux
+ OSVersion string // 10.0.22621
+ NodeVersion string // 18.x/20.x/22.x
+ KiroVersion string // 0.3.x-0.8.x
+ KiroHash string // SHA256
+ AcceptLanguage string
+ ScreenResolution string // 1920x1080
+ ColorDepth int // 24
+ HardwareConcurrency int // CPU 核心数
+ TimezoneOffset int
+}
+
+// FingerprintManager 指纹管理器
+type FingerprintManager struct {
+ mu sync.RWMutex
+ fingerprints map[string]*Fingerprint // tokenKey -> fingerprint
+ rng *rand.Rand
+}
+
+var (
+ sdkVersions = []string{
+ "1.0.20", "1.0.21", "1.0.22", "1.0.23",
+ "1.0.24", "1.0.25", "1.0.26", "1.0.27",
+ }
+ osTypes = []string{"darwin", "windows", "linux"}
+ osVersions = map[string][]string{
+ "darwin": {"14.0", "14.1", "14.2", "14.3", "14.4", "14.5", "15.0", "15.1"},
+ "windows": {"10.0.19041", "10.0.19042", "10.0.19043", "10.0.19044", "10.0.22621", "10.0.22631"},
+ "linux": {"5.15.0", "6.1.0", "6.2.0", "6.5.0", "6.6.0", "6.8.0"},
+ }
+ nodeVersions = []string{
+ "18.17.0", "18.18.0", "18.19.0", "18.20.0",
+ "20.9.0", "20.10.0", "20.11.0", "20.12.0", "20.13.0",
+ "22.0.0", "22.1.0", "22.2.0", "22.3.0",
+ }
+ kiroVersions = []string{
+ "0.3.0", "0.3.1", "0.4.0", "0.4.1", "0.5.0", "0.5.1",
+ "0.6.0", "0.6.1", "0.7.0", "0.7.1", "0.8.0", "0.8.1",
+ }
+ acceptLanguages = []string{
+ "en-US,en;q=0.9",
+ "en-GB,en;q=0.9",
+ "zh-CN,zh;q=0.9,en;q=0.8",
+ "zh-TW,zh;q=0.9,en;q=0.8",
+ "ja-JP,ja;q=0.9,en;q=0.8",
+ "ko-KR,ko;q=0.9,en;q=0.8",
+ "de-DE,de;q=0.9,en;q=0.8",
+ "fr-FR,fr;q=0.9,en;q=0.8",
+ }
+ screenResolutions = []string{
+ "1920x1080", "2560x1440", "3840x2160",
+ "1366x768", "1440x900", "1680x1050",
+ "2560x1600", "3440x1440",
+ }
+ colorDepths = []int{24, 32}
+ hardwareConcurrencies = []int{4, 6, 8, 10, 12, 16, 20, 24, 32}
+ timezoneOffsets = []int{-480, -420, -360, -300, -240, 0, 60, 120, 480, 540}
+)
+
+// NewFingerprintManager 创建指纹管理器
+func NewFingerprintManager() *FingerprintManager {
+ return &FingerprintManager{
+ fingerprints: make(map[string]*Fingerprint),
+ rng: rand.New(rand.NewSource(time.Now().UnixNano())),
+ }
+}
+
+// GetFingerprint 获取或生成 Token 关联的指纹
+func (fm *FingerprintManager) GetFingerprint(tokenKey string) *Fingerprint {
+ fm.mu.RLock()
+ if fp, exists := fm.fingerprints[tokenKey]; exists {
+ fm.mu.RUnlock()
+ return fp
+ }
+ fm.mu.RUnlock()
+
+ fm.mu.Lock()
+ defer fm.mu.Unlock()
+
+ if fp, exists := fm.fingerprints[tokenKey]; exists {
+ return fp
+ }
+
+ fp := fm.generateFingerprint(tokenKey)
+ fm.fingerprints[tokenKey] = fp
+ return fp
+}
+
+// generateFingerprint 生成新的指纹
+func (fm *FingerprintManager) generateFingerprint(tokenKey string) *Fingerprint {
+ osType := fm.randomChoice(osTypes)
+ osVersion := fm.randomChoice(osVersions[osType])
+ kiroVersion := fm.randomChoice(kiroVersions)
+
+ fp := &Fingerprint{
+ SDKVersion: fm.randomChoice(sdkVersions),
+ OSType: osType,
+ OSVersion: osVersion,
+ NodeVersion: fm.randomChoice(nodeVersions),
+ KiroVersion: kiroVersion,
+ AcceptLanguage: fm.randomChoice(acceptLanguages),
+ ScreenResolution: fm.randomChoice(screenResolutions),
+ ColorDepth: fm.randomIntChoice(colorDepths),
+ HardwareConcurrency: fm.randomIntChoice(hardwareConcurrencies),
+ TimezoneOffset: fm.randomIntChoice(timezoneOffsets),
+ }
+
+ fp.KiroHash = fm.generateKiroHash(tokenKey, kiroVersion, osType)
+ return fp
+}
+
+// generateKiroHash 生成 Kiro Hash
+func (fm *FingerprintManager) generateKiroHash(tokenKey, kiroVersion, osType string) string {
+ data := fmt.Sprintf("%s:%s:%s:%d", tokenKey, kiroVersion, osType, time.Now().UnixNano())
+ hash := sha256.Sum256([]byte(data))
+ return hex.EncodeToString(hash[:])
+}
+
+// randomChoice 随机选择字符串
+func (fm *FingerprintManager) randomChoice(choices []string) string {
+ return choices[fm.rng.Intn(len(choices))]
+}
+
+// randomIntChoice 随机选择整数
+func (fm *FingerprintManager) randomIntChoice(choices []int) int {
+ return choices[fm.rng.Intn(len(choices))]
+}
+
+// ApplyToRequest 将指纹信息应用到 HTTP 请求头
+func (fp *Fingerprint) ApplyToRequest(req *http.Request) {
+ req.Header.Set("X-Kiro-SDK-Version", fp.SDKVersion)
+ req.Header.Set("X-Kiro-OS-Type", fp.OSType)
+ req.Header.Set("X-Kiro-OS-Version", fp.OSVersion)
+ req.Header.Set("X-Kiro-Node-Version", fp.NodeVersion)
+ req.Header.Set("X-Kiro-Version", fp.KiroVersion)
+ req.Header.Set("X-Kiro-Hash", fp.KiroHash)
+ req.Header.Set("Accept-Language", fp.AcceptLanguage)
+ req.Header.Set("X-Screen-Resolution", fp.ScreenResolution)
+ req.Header.Set("X-Color-Depth", fmt.Sprintf("%d", fp.ColorDepth))
+ req.Header.Set("X-Hardware-Concurrency", fmt.Sprintf("%d", fp.HardwareConcurrency))
+ req.Header.Set("X-Timezone-Offset", fmt.Sprintf("%d", fp.TimezoneOffset))
+}
+
+// RemoveFingerprint 移除 Token 关联的指纹
+func (fm *FingerprintManager) RemoveFingerprint(tokenKey string) {
+ fm.mu.Lock()
+ defer fm.mu.Unlock()
+ delete(fm.fingerprints, tokenKey)
+}
+
+// Count 返回当前管理的指纹数量
+func (fm *FingerprintManager) Count() int {
+ fm.mu.RLock()
+ defer fm.mu.RUnlock()
+ return len(fm.fingerprints)
+}
+
+// BuildUserAgent 构建 User-Agent 字符串 (Kiro IDE 风格)
+// 格式: aws-sdk-js/{SDKVersion} ua/2.1 os/{OSType}#{OSVersion} lang/js md/nodejs#{NodeVersion} api/codewhispererstreaming#{SDKVersion} m/E KiroIDE-{KiroVersion}-{KiroHash}
+func (fp *Fingerprint) BuildUserAgent() string {
+ return fmt.Sprintf(
+ "aws-sdk-js/%s ua/2.1 os/%s#%s lang/js md/nodejs#%s api/codewhispererstreaming#%s m/E KiroIDE-%s-%s",
+ fp.SDKVersion,
+ fp.OSType,
+ fp.OSVersion,
+ fp.NodeVersion,
+ fp.SDKVersion,
+ fp.KiroVersion,
+ fp.KiroHash,
+ )
+}
+
+// BuildAmzUserAgent 构建 X-Amz-User-Agent 字符串
+// 格式: aws-sdk-js/{SDKVersion} KiroIDE-{KiroVersion}-{KiroHash}
+func (fp *Fingerprint) BuildAmzUserAgent() string {
+ return fmt.Sprintf(
+ "aws-sdk-js/%s KiroIDE-%s-%s",
+ fp.SDKVersion,
+ fp.KiroVersion,
+ fp.KiroHash,
+ )
+}
diff --git a/internal/auth/kiro/fingerprint_test.go b/internal/auth/kiro/fingerprint_test.go
new file mode 100644
index 00000000..e0ae51f2
--- /dev/null
+++ b/internal/auth/kiro/fingerprint_test.go
@@ -0,0 +1,227 @@
+package kiro
+
+import (
+ "net/http"
+ "sync"
+ "testing"
+)
+
+func TestNewFingerprintManager(t *testing.T) {
+ fm := NewFingerprintManager()
+ if fm == nil {
+ t.Fatal("expected non-nil FingerprintManager")
+ }
+ if fm.fingerprints == nil {
+ t.Error("expected non-nil fingerprints map")
+ }
+ if fm.rng == nil {
+ t.Error("expected non-nil rng")
+ }
+}
+
+func TestGetFingerprint_NewToken(t *testing.T) {
+ fm := NewFingerprintManager()
+ fp := fm.GetFingerprint("token1")
+
+ if fp == nil {
+ t.Fatal("expected non-nil Fingerprint")
+ }
+ if fp.SDKVersion == "" {
+ t.Error("expected non-empty SDKVersion")
+ }
+ if fp.OSType == "" {
+ t.Error("expected non-empty OSType")
+ }
+ if fp.OSVersion == "" {
+ t.Error("expected non-empty OSVersion")
+ }
+ if fp.NodeVersion == "" {
+ t.Error("expected non-empty NodeVersion")
+ }
+ if fp.KiroVersion == "" {
+ t.Error("expected non-empty KiroVersion")
+ }
+ if fp.KiroHash == "" {
+ t.Error("expected non-empty KiroHash")
+ }
+ if fp.AcceptLanguage == "" {
+ t.Error("expected non-empty AcceptLanguage")
+ }
+ if fp.ScreenResolution == "" {
+ t.Error("expected non-empty ScreenResolution")
+ }
+ if fp.ColorDepth == 0 {
+ t.Error("expected non-zero ColorDepth")
+ }
+ if fp.HardwareConcurrency == 0 {
+ t.Error("expected non-zero HardwareConcurrency")
+ }
+}
+
+func TestGetFingerprint_SameTokenReturnsSameFingerprint(t *testing.T) {
+ fm := NewFingerprintManager()
+ fp1 := fm.GetFingerprint("token1")
+ fp2 := fm.GetFingerprint("token1")
+
+ if fp1 != fp2 {
+ t.Error("expected same fingerprint for same token")
+ }
+}
+
+func TestGetFingerprint_DifferentTokens(t *testing.T) {
+ fm := NewFingerprintManager()
+ fp1 := fm.GetFingerprint("token1")
+ fp2 := fm.GetFingerprint("token2")
+
+ if fp1 == fp2 {
+ t.Error("expected different fingerprints for different tokens")
+ }
+}
+
+func TestRemoveFingerprint(t *testing.T) {
+ fm := NewFingerprintManager()
+ fm.GetFingerprint("token1")
+ if fm.Count() != 1 {
+ t.Fatalf("expected count 1, got %d", fm.Count())
+ }
+
+ fm.RemoveFingerprint("token1")
+ if fm.Count() != 0 {
+ t.Errorf("expected count 0, got %d", fm.Count())
+ }
+}
+
+func TestRemoveFingerprint_NonExistent(t *testing.T) {
+ fm := NewFingerprintManager()
+ fm.RemoveFingerprint("nonexistent")
+ if fm.Count() != 0 {
+ t.Errorf("expected count 0, got %d", fm.Count())
+ }
+}
+
+func TestCount(t *testing.T) {
+ fm := NewFingerprintManager()
+ if fm.Count() != 0 {
+ t.Errorf("expected count 0, got %d", fm.Count())
+ }
+
+ fm.GetFingerprint("token1")
+ fm.GetFingerprint("token2")
+ fm.GetFingerprint("token3")
+
+ if fm.Count() != 3 {
+ t.Errorf("expected count 3, got %d", fm.Count())
+ }
+}
+
+func TestApplyToRequest(t *testing.T) {
+ fm := NewFingerprintManager()
+ fp := fm.GetFingerprint("token1")
+
+ req, _ := http.NewRequest("GET", "http://example.com", nil)
+ fp.ApplyToRequest(req)
+
+ if req.Header.Get("X-Kiro-SDK-Version") != fp.SDKVersion {
+ t.Error("X-Kiro-SDK-Version header mismatch")
+ }
+ if req.Header.Get("X-Kiro-OS-Type") != fp.OSType {
+ t.Error("X-Kiro-OS-Type header mismatch")
+ }
+ if req.Header.Get("X-Kiro-OS-Version") != fp.OSVersion {
+ t.Error("X-Kiro-OS-Version header mismatch")
+ }
+ if req.Header.Get("X-Kiro-Node-Version") != fp.NodeVersion {
+ t.Error("X-Kiro-Node-Version header mismatch")
+ }
+ if req.Header.Get("X-Kiro-Version") != fp.KiroVersion {
+ t.Error("X-Kiro-Version header mismatch")
+ }
+ if req.Header.Get("X-Kiro-Hash") != fp.KiroHash {
+ t.Error("X-Kiro-Hash header mismatch")
+ }
+ if req.Header.Get("Accept-Language") != fp.AcceptLanguage {
+ t.Error("Accept-Language header mismatch")
+ }
+ if req.Header.Get("X-Screen-Resolution") != fp.ScreenResolution {
+ t.Error("X-Screen-Resolution header mismatch")
+ }
+}
+
+func TestGetFingerprint_OSVersionMatchesOSType(t *testing.T) {
+ fm := NewFingerprintManager()
+
+ for i := 0; i < 20; i++ {
+ fp := fm.GetFingerprint("token" + string(rune('a'+i)))
+ validVersions := osVersions[fp.OSType]
+ found := false
+ for _, v := range validVersions {
+ if v == fp.OSVersion {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Errorf("OS version %s not valid for OS type %s", fp.OSVersion, fp.OSType)
+ }
+ }
+}
+
+func TestFingerprintManager_ConcurrentAccess(t *testing.T) {
+ fm := NewFingerprintManager()
+ const numGoroutines = 100
+ const numOperations = 100
+
+ var wg sync.WaitGroup
+ wg.Add(numGoroutines)
+
+ for i := 0; i < numGoroutines; i++ {
+ go func(id int) {
+ defer wg.Done()
+ for j := 0; j < numOperations; j++ {
+ tokenKey := "token" + string(rune('a'+id%26))
+ switch j % 4 {
+ case 0:
+ fm.GetFingerprint(tokenKey)
+ case 1:
+ fm.Count()
+ case 2:
+ fp := fm.GetFingerprint(tokenKey)
+ req, _ := http.NewRequest("GET", "http://example.com", nil)
+ fp.ApplyToRequest(req)
+ case 3:
+ fm.RemoveFingerprint(tokenKey)
+ }
+ }
+ }(i)
+ }
+
+ wg.Wait()
+}
+
+func TestKiroHashUniqueness(t *testing.T) {
+ fm := NewFingerprintManager()
+ hashes := make(map[string]bool)
+
+ for i := 0; i < 100; i++ {
+ fp := fm.GetFingerprint("token" + string(rune(i)))
+ if hashes[fp.KiroHash] {
+ t.Errorf("duplicate KiroHash detected: %s", fp.KiroHash)
+ }
+ hashes[fp.KiroHash] = true
+ }
+}
+
+func TestKiroHashFormat(t *testing.T) {
+ fm := NewFingerprintManager()
+ fp := fm.GetFingerprint("token1")
+
+ if len(fp.KiroHash) != 64 {
+ t.Errorf("expected KiroHash length 64 (SHA256 hex), got %d", len(fp.KiroHash))
+ }
+
+ for _, c := range fp.KiroHash {
+ if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
+ t.Errorf("invalid hex character in KiroHash: %c", c)
+ }
+ }
+}
diff --git a/internal/auth/kiro/jitter.go b/internal/auth/kiro/jitter.go
new file mode 100644
index 00000000..0569a8fb
--- /dev/null
+++ b/internal/auth/kiro/jitter.go
@@ -0,0 +1,174 @@
+package kiro
+
+import (
+ "math/rand"
+ "sync"
+ "time"
+)
+
+// Jitter configuration constants
+const (
+ // JitterPercent is the default percentage of jitter to apply (±30%)
+ JitterPercent = 0.30
+
+ // Human-like delay ranges
+ ShortDelayMin = 50 * time.Millisecond // Minimum for rapid consecutive operations
+ ShortDelayMax = 200 * time.Millisecond // Maximum for rapid consecutive operations
+ NormalDelayMin = 1 * time.Second // Minimum for normal thinking time
+ NormalDelayMax = 3 * time.Second // Maximum for normal thinking time
+ LongDelayMin = 5 * time.Second // Minimum for reading/resting
+ LongDelayMax = 10 * time.Second // Maximum for reading/resting
+
+ // Probability thresholds for human-like behavior
+ ShortDelayProbability = 0.20 // 20% chance of short delay (consecutive ops)
+ LongDelayProbability = 0.05 // 5% chance of long delay (reading/resting)
+ NormalDelayProbability = 0.75 // 75% chance of normal delay (thinking)
+)
+
+var (
+ jitterRand *rand.Rand
+ jitterRandOnce sync.Once
+ jitterMu sync.Mutex
+ lastRequestTime time.Time
+)
+
+// initJitterRand initializes the random number generator for jitter calculations.
+// Uses a time-based seed for unpredictable but reproducible randomness.
+func initJitterRand() {
+ jitterRandOnce.Do(func() {
+ jitterRand = rand.New(rand.NewSource(time.Now().UnixNano()))
+ })
+}
+
+// RandomDelay generates a random delay between min and max duration.
+// Thread-safe implementation using mutex protection.
+func RandomDelay(min, max time.Duration) time.Duration {
+ initJitterRand()
+ jitterMu.Lock()
+ defer jitterMu.Unlock()
+
+ if min >= max {
+ return min
+ }
+
+ rangeMs := max.Milliseconds() - min.Milliseconds()
+ randomMs := jitterRand.Int63n(rangeMs)
+ return min + time.Duration(randomMs)*time.Millisecond
+}
+
+// JitterDelay adds jitter to a base delay.
+// Applies ±jitterPercent variation to the base delay.
+// For example, JitterDelay(1*time.Second, 0.30) returns a value between 700ms and 1300ms.
+func JitterDelay(baseDelay time.Duration, jitterPercent float64) time.Duration {
+ initJitterRand()
+ jitterMu.Lock()
+ defer jitterMu.Unlock()
+
+ if jitterPercent <= 0 || jitterPercent > 1 {
+ jitterPercent = JitterPercent
+ }
+
+ // Calculate jitter range: base * jitterPercent
+ jitterRange := float64(baseDelay) * jitterPercent
+
+ // Generate random value in range [-jitterRange, +jitterRange]
+ jitter := (jitterRand.Float64()*2 - 1) * jitterRange
+
+ result := time.Duration(float64(baseDelay) + jitter)
+ if result < 0 {
+ return 0
+ }
+ return result
+}
+
+// JitterDelayDefault applies the default ±30% jitter to a base delay.
+func JitterDelayDefault(baseDelay time.Duration) time.Duration {
+ return JitterDelay(baseDelay, JitterPercent)
+}
+
+// HumanLikeDelay generates a delay that mimics human behavior patterns.
+// The delay is selected based on probability distribution:
+// - 20% chance: Short delay (50-200ms) - simulates consecutive rapid operations
+// - 75% chance: Normal delay (1-3s) - simulates thinking/reading time
+// - 5% chance: Long delay (5-10s) - simulates breaks/reading longer content
+//
+// Returns the delay duration (caller should call time.Sleep with this value).
+func HumanLikeDelay() time.Duration {
+ initJitterRand()
+ jitterMu.Lock()
+ defer jitterMu.Unlock()
+
+ // Track time since last request for adaptive behavior
+ now := time.Now()
+ timeSinceLastRequest := now.Sub(lastRequestTime)
+ lastRequestTime = now
+
+ // If requests are very close together, use short delay
+ if timeSinceLastRequest < 500*time.Millisecond && timeSinceLastRequest > 0 {
+ rangeMs := ShortDelayMax.Milliseconds() - ShortDelayMin.Milliseconds()
+ randomMs := jitterRand.Int63n(rangeMs)
+ return ShortDelayMin + time.Duration(randomMs)*time.Millisecond
+ }
+
+ // Otherwise, use probability-based selection
+ roll := jitterRand.Float64()
+
+ var min, max time.Duration
+ switch {
+ case roll < ShortDelayProbability:
+ // Short delay - consecutive operations
+ min, max = ShortDelayMin, ShortDelayMax
+ case roll < ShortDelayProbability+LongDelayProbability:
+ // Long delay - reading/resting
+ min, max = LongDelayMin, LongDelayMax
+ default:
+ // Normal delay - thinking time
+ min, max = NormalDelayMin, NormalDelayMax
+ }
+
+ rangeMs := max.Milliseconds() - min.Milliseconds()
+ randomMs := jitterRand.Int63n(rangeMs)
+ return min + time.Duration(randomMs)*time.Millisecond
+}
+
+// ApplyHumanLikeDelay applies human-like delay by sleeping.
+// This is a convenience function that combines HumanLikeDelay with time.Sleep.
+func ApplyHumanLikeDelay() {
+ delay := HumanLikeDelay()
+ if delay > 0 {
+ time.Sleep(delay)
+ }
+}
+
+// ExponentialBackoffWithJitter calculates retry delay using exponential backoff with jitter.
+// Formula: min(baseDelay * 2^attempt + jitter, maxDelay)
+// This helps prevent thundering herd problem when multiple clients retry simultaneously.
+func ExponentialBackoffWithJitter(attempt int, baseDelay, maxDelay time.Duration) time.Duration {
+ if attempt < 0 {
+ attempt = 0
+ }
+
+ // Calculate exponential backoff: baseDelay * 2^attempt
+ backoff := baseDelay * time.Duration(1< maxDelay {
+ backoff = maxDelay
+ }
+
+ // Add ±30% jitter
+ return JitterDelay(backoff, JitterPercent)
+}
+
+// ShouldSkipDelay determines if delay should be skipped based on context.
+// Returns true for streaming responses, WebSocket connections, etc.
+// This function can be extended to check additional skip conditions.
+func ShouldSkipDelay(isStreaming bool) bool {
+ return isStreaming
+}
+
+// ResetLastRequestTime resets the last request time tracker.
+// Useful for testing or when starting a new session.
+func ResetLastRequestTime() {
+ jitterMu.Lock()
+ defer jitterMu.Unlock()
+ lastRequestTime = time.Time{}
+}
diff --git a/internal/auth/kiro/metrics.go b/internal/auth/kiro/metrics.go
new file mode 100644
index 00000000..0fe2d0c6
--- /dev/null
+++ b/internal/auth/kiro/metrics.go
@@ -0,0 +1,187 @@
+package kiro
+
+import (
+ "math"
+ "sync"
+ "time"
+)
+
+// TokenMetrics holds performance metrics for a single token.
+type TokenMetrics struct {
+ SuccessRate float64 // Success rate (0.0 - 1.0)
+ AvgLatency float64 // Average latency in milliseconds
+ QuotaRemaining float64 // Remaining quota (0.0 - 1.0)
+ LastUsed time.Time // Last usage timestamp
+ FailCount int // Consecutive failure count
+ TotalRequests int // Total request count
+ successCount int // Internal: successful request count
+ totalLatency float64 // Internal: cumulative latency
+}
+
+// TokenScorer manages token metrics and scoring.
+type TokenScorer struct {
+ mu sync.RWMutex
+ metrics map[string]*TokenMetrics
+
+ // Scoring weights
+ successRateWeight float64
+ quotaWeight float64
+ latencyWeight float64
+ lastUsedWeight float64
+ failPenaltyMultiplier float64
+}
+
+// NewTokenScorer creates a new TokenScorer with default weights.
+func NewTokenScorer() *TokenScorer {
+ return &TokenScorer{
+ metrics: make(map[string]*TokenMetrics),
+ successRateWeight: 0.4,
+ quotaWeight: 0.25,
+ latencyWeight: 0.2,
+ lastUsedWeight: 0.15,
+ failPenaltyMultiplier: 0.1,
+ }
+}
+
+// getOrCreateMetrics returns existing metrics or creates new ones.
+func (s *TokenScorer) getOrCreateMetrics(tokenKey string) *TokenMetrics {
+ if m, ok := s.metrics[tokenKey]; ok {
+ return m
+ }
+ m := &TokenMetrics{
+ SuccessRate: 1.0,
+ QuotaRemaining: 1.0,
+ }
+ s.metrics[tokenKey] = m
+ return m
+}
+
+// RecordRequest records the result of a request for a token.
+func (s *TokenScorer) RecordRequest(tokenKey string, success bool, latency time.Duration) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ m := s.getOrCreateMetrics(tokenKey)
+ m.TotalRequests++
+ m.LastUsed = time.Now()
+ m.totalLatency += float64(latency.Milliseconds())
+
+ if success {
+ m.successCount++
+ m.FailCount = 0
+ } else {
+ m.FailCount++
+ }
+
+ // Update derived metrics
+ if m.TotalRequests > 0 {
+ m.SuccessRate = float64(m.successCount) / float64(m.TotalRequests)
+ m.AvgLatency = m.totalLatency / float64(m.TotalRequests)
+ }
+}
+
+// SetQuotaRemaining updates the remaining quota for a token.
+func (s *TokenScorer) SetQuotaRemaining(tokenKey string, quota float64) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ m := s.getOrCreateMetrics(tokenKey)
+ m.QuotaRemaining = quota
+}
+
+// GetMetrics returns a copy of the metrics for a token.
+func (s *TokenScorer) GetMetrics(tokenKey string) *TokenMetrics {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if m, ok := s.metrics[tokenKey]; ok {
+ copy := *m
+ return ©
+ }
+ return nil
+}
+
+// CalculateScore computes the score for a token (higher is better).
+func (s *TokenScorer) CalculateScore(tokenKey string) float64 {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ m, ok := s.metrics[tokenKey]
+ if !ok {
+ return 1.0 // New tokens get a high initial score
+ }
+
+ // Success rate component (0-1)
+ successScore := m.SuccessRate
+
+ // Quota component (0-1)
+ quotaScore := m.QuotaRemaining
+
+ // Latency component (normalized, lower is better)
+ // Using exponential decay: score = e^(-latency/1000)
+ // 1000ms latency -> ~0.37 score, 100ms -> ~0.90 score
+ latencyScore := math.Exp(-m.AvgLatency / 1000.0)
+ if m.TotalRequests == 0 {
+ latencyScore = 1.0
+ }
+
+ // Last used component (prefer tokens not recently used)
+ // Score increases as time since last use increases
+ timeSinceUse := time.Since(m.LastUsed).Seconds()
+ // Normalize: 60 seconds -> ~0.63 score, 0 seconds -> 0 score
+ lastUsedScore := 1.0 - math.Exp(-timeSinceUse/60.0)
+ if m.LastUsed.IsZero() {
+ lastUsedScore = 1.0
+ }
+
+ // Calculate weighted score
+ score := s.successRateWeight*successScore +
+ s.quotaWeight*quotaScore +
+ s.latencyWeight*latencyScore +
+ s.lastUsedWeight*lastUsedScore
+
+ // Apply consecutive failure penalty
+ if m.FailCount > 0 {
+ penalty := s.failPenaltyMultiplier * float64(m.FailCount)
+ score = score * math.Max(0, 1.0-penalty)
+ }
+
+ return score
+}
+
+// SelectBestToken selects the token with the highest score.
+func (s *TokenScorer) SelectBestToken(tokens []string) string {
+ if len(tokens) == 0 {
+ return ""
+ }
+ if len(tokens) == 1 {
+ return tokens[0]
+ }
+
+ bestToken := tokens[0]
+ bestScore := s.CalculateScore(tokens[0])
+
+ for _, token := range tokens[1:] {
+ score := s.CalculateScore(token)
+ if score > bestScore {
+ bestScore = score
+ bestToken = token
+ }
+ }
+
+ return bestToken
+}
+
+// ResetMetrics clears all metrics for a token.
+func (s *TokenScorer) ResetMetrics(tokenKey string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ delete(s.metrics, tokenKey)
+}
+
+// ResetAllMetrics clears all stored metrics.
+func (s *TokenScorer) ResetAllMetrics() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.metrics = make(map[string]*TokenMetrics)
+}
diff --git a/internal/auth/kiro/metrics_test.go b/internal/auth/kiro/metrics_test.go
new file mode 100644
index 00000000..ffe2a876
--- /dev/null
+++ b/internal/auth/kiro/metrics_test.go
@@ -0,0 +1,301 @@
+package kiro
+
+import (
+ "sync"
+ "testing"
+ "time"
+)
+
+func TestNewTokenScorer(t *testing.T) {
+ s := NewTokenScorer()
+ if s == nil {
+ t.Fatal("expected non-nil TokenScorer")
+ }
+ if s.metrics == nil {
+ t.Error("expected non-nil metrics map")
+ }
+ if s.successRateWeight != 0.4 {
+ t.Errorf("expected successRateWeight 0.4, got %f", s.successRateWeight)
+ }
+ if s.quotaWeight != 0.25 {
+ t.Errorf("expected quotaWeight 0.25, got %f", s.quotaWeight)
+ }
+}
+
+func TestRecordRequest_Success(t *testing.T) {
+ s := NewTokenScorer()
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+
+ m := s.GetMetrics("token1")
+ if m == nil {
+ t.Fatal("expected non-nil metrics")
+ }
+ if m.TotalRequests != 1 {
+ t.Errorf("expected TotalRequests 1, got %d", m.TotalRequests)
+ }
+ if m.SuccessRate != 1.0 {
+ t.Errorf("expected SuccessRate 1.0, got %f", m.SuccessRate)
+ }
+ if m.FailCount != 0 {
+ t.Errorf("expected FailCount 0, got %d", m.FailCount)
+ }
+ if m.AvgLatency != 100 {
+ t.Errorf("expected AvgLatency 100, got %f", m.AvgLatency)
+ }
+}
+
+func TestRecordRequest_Failure(t *testing.T) {
+ s := NewTokenScorer()
+ s.RecordRequest("token1", false, 200*time.Millisecond)
+
+ m := s.GetMetrics("token1")
+ if m.SuccessRate != 0.0 {
+ t.Errorf("expected SuccessRate 0.0, got %f", m.SuccessRate)
+ }
+ if m.FailCount != 1 {
+ t.Errorf("expected FailCount 1, got %d", m.FailCount)
+ }
+}
+
+func TestRecordRequest_MixedResults(t *testing.T) {
+ s := NewTokenScorer()
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+ s.RecordRequest("token1", false, 100*time.Millisecond)
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+
+ m := s.GetMetrics("token1")
+ if m.TotalRequests != 4 {
+ t.Errorf("expected TotalRequests 4, got %d", m.TotalRequests)
+ }
+ if m.SuccessRate != 0.75 {
+ t.Errorf("expected SuccessRate 0.75, got %f", m.SuccessRate)
+ }
+ if m.FailCount != 0 {
+ t.Errorf("expected FailCount 0 (reset on success), got %d", m.FailCount)
+ }
+}
+
+func TestRecordRequest_ConsecutiveFailures(t *testing.T) {
+ s := NewTokenScorer()
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+ s.RecordRequest("token1", false, 100*time.Millisecond)
+ s.RecordRequest("token1", false, 100*time.Millisecond)
+ s.RecordRequest("token1", false, 100*time.Millisecond)
+
+ m := s.GetMetrics("token1")
+ if m.FailCount != 3 {
+ t.Errorf("expected FailCount 3, got %d", m.FailCount)
+ }
+}
+
+func TestSetQuotaRemaining(t *testing.T) {
+ s := NewTokenScorer()
+ s.SetQuotaRemaining("token1", 0.5)
+
+ m := s.GetMetrics("token1")
+ if m.QuotaRemaining != 0.5 {
+ t.Errorf("expected QuotaRemaining 0.5, got %f", m.QuotaRemaining)
+ }
+}
+
+func TestGetMetrics_NonExistent(t *testing.T) {
+ s := NewTokenScorer()
+ m := s.GetMetrics("nonexistent")
+ if m != nil {
+ t.Error("expected nil metrics for non-existent token")
+ }
+}
+
+func TestGetMetrics_ReturnsCopy(t *testing.T) {
+ s := NewTokenScorer()
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+
+ m1 := s.GetMetrics("token1")
+ m1.TotalRequests = 999
+
+ m2 := s.GetMetrics("token1")
+ if m2.TotalRequests == 999 {
+ t.Error("GetMetrics should return a copy")
+ }
+}
+
+func TestCalculateScore_NewToken(t *testing.T) {
+ s := NewTokenScorer()
+ score := s.CalculateScore("newtoken")
+ if score != 1.0 {
+ t.Errorf("expected score 1.0 for new token, got %f", score)
+ }
+}
+
+func TestCalculateScore_PerfectToken(t *testing.T) {
+ s := NewTokenScorer()
+ s.RecordRequest("token1", true, 50*time.Millisecond)
+ s.SetQuotaRemaining("token1", 1.0)
+
+ time.Sleep(100 * time.Millisecond)
+ score := s.CalculateScore("token1")
+ if score < 0.5 || score > 1.0 {
+ t.Errorf("expected high score for perfect token, got %f", score)
+ }
+}
+
+func TestCalculateScore_FailedToken(t *testing.T) {
+ s := NewTokenScorer()
+ for i := 0; i < 5; i++ {
+ s.RecordRequest("token1", false, 1000*time.Millisecond)
+ }
+ s.SetQuotaRemaining("token1", 0.1)
+
+ score := s.CalculateScore("token1")
+ if score > 0.5 {
+ t.Errorf("expected low score for failed token, got %f", score)
+ }
+}
+
+func TestCalculateScore_FailPenalty(t *testing.T) {
+ s := NewTokenScorer()
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+ scoreNoFail := s.CalculateScore("token1")
+
+ s.RecordRequest("token1", false, 100*time.Millisecond)
+ s.RecordRequest("token1", false, 100*time.Millisecond)
+ scoreWithFail := s.CalculateScore("token1")
+
+ if scoreWithFail >= scoreNoFail {
+ t.Errorf("expected lower score with consecutive failures: noFail=%f, withFail=%f", scoreNoFail, scoreWithFail)
+ }
+}
+
+func TestSelectBestToken_Empty(t *testing.T) {
+ s := NewTokenScorer()
+ best := s.SelectBestToken([]string{})
+ if best != "" {
+ t.Errorf("expected empty string for empty tokens, got %s", best)
+ }
+}
+
+func TestSelectBestToken_SingleToken(t *testing.T) {
+ s := NewTokenScorer()
+ best := s.SelectBestToken([]string{"token1"})
+ if best != "token1" {
+ t.Errorf("expected token1, got %s", best)
+ }
+}
+
+func TestSelectBestToken_MultipleTokens(t *testing.T) {
+ s := NewTokenScorer()
+
+ s.RecordRequest("bad", false, 1000*time.Millisecond)
+ s.RecordRequest("bad", false, 1000*time.Millisecond)
+ s.SetQuotaRemaining("bad", 0.1)
+
+ s.RecordRequest("good", true, 50*time.Millisecond)
+ s.SetQuotaRemaining("good", 0.9)
+
+ time.Sleep(50 * time.Millisecond)
+
+ best := s.SelectBestToken([]string{"bad", "good"})
+ if best != "good" {
+ t.Errorf("expected good token to be selected, got %s", best)
+ }
+}
+
+func TestResetMetrics(t *testing.T) {
+ s := NewTokenScorer()
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+ s.ResetMetrics("token1")
+
+ m := s.GetMetrics("token1")
+ if m != nil {
+ t.Error("expected nil metrics after reset")
+ }
+}
+
+func TestResetAllMetrics(t *testing.T) {
+ s := NewTokenScorer()
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+ s.RecordRequest("token2", true, 100*time.Millisecond)
+ s.RecordRequest("token3", true, 100*time.Millisecond)
+
+ s.ResetAllMetrics()
+
+ if s.GetMetrics("token1") != nil {
+ t.Error("expected nil metrics for token1 after reset all")
+ }
+ if s.GetMetrics("token2") != nil {
+ t.Error("expected nil metrics for token2 after reset all")
+ }
+}
+
+func TestTokenScorer_ConcurrentAccess(t *testing.T) {
+ s := NewTokenScorer()
+ const numGoroutines = 50
+ const numOperations = 100
+
+ var wg sync.WaitGroup
+ wg.Add(numGoroutines)
+
+ for i := 0; i < numGoroutines; i++ {
+ go func(id int) {
+ defer wg.Done()
+ tokenKey := "token" + string(rune('a'+id%10))
+ for j := 0; j < numOperations; j++ {
+ switch j % 6 {
+ case 0:
+ s.RecordRequest(tokenKey, j%2 == 0, time.Duration(j)*time.Millisecond)
+ case 1:
+ s.SetQuotaRemaining(tokenKey, float64(j%100)/100)
+ case 2:
+ s.GetMetrics(tokenKey)
+ case 3:
+ s.CalculateScore(tokenKey)
+ case 4:
+ s.SelectBestToken([]string{tokenKey, "token_x", "token_y"})
+ case 5:
+ if j%20 == 0 {
+ s.ResetMetrics(tokenKey)
+ }
+ }
+ }
+ }(i)
+ }
+
+ wg.Wait()
+}
+
+func TestAvgLatencyCalculation(t *testing.T) {
+ s := NewTokenScorer()
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+ s.RecordRequest("token1", true, 200*time.Millisecond)
+ s.RecordRequest("token1", true, 300*time.Millisecond)
+
+ m := s.GetMetrics("token1")
+ if m.AvgLatency != 200 {
+ t.Errorf("expected AvgLatency 200, got %f", m.AvgLatency)
+ }
+}
+
+func TestLastUsedUpdated(t *testing.T) {
+ s := NewTokenScorer()
+ before := time.Now()
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+
+ m := s.GetMetrics("token1")
+ if m.LastUsed.Before(before) {
+ t.Error("expected LastUsed to be after test start time")
+ }
+ if m.LastUsed.After(time.Now()) {
+ t.Error("expected LastUsed to be before or equal to now")
+ }
+}
+
+func TestDefaultQuotaForNewToken(t *testing.T) {
+ s := NewTokenScorer()
+ s.RecordRequest("token1", true, 100*time.Millisecond)
+
+ m := s.GetMetrics("token1")
+ if m.QuotaRemaining != 1.0 {
+ t.Errorf("expected default QuotaRemaining 1.0, got %f", m.QuotaRemaining)
+ }
+}
diff --git a/internal/auth/kiro/oauth.go b/internal/auth/kiro/oauth.go
new file mode 100644
index 00000000..a286cf42
--- /dev/null
+++ b/internal/auth/kiro/oauth.go
@@ -0,0 +1,329 @@
+// Package kiro provides OAuth2 authentication for Kiro using native Google login.
+package kiro
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "html"
+ "io"
+ "net"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ // Kiro auth endpoint
+ kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
+
+ // Default callback port
+ defaultCallbackPort = 9876
+
+ // Auth timeout
+ authTimeout = 10 * time.Minute
+)
+
+// KiroTokenResponse represents the response from Kiro token endpoint.
+type KiroTokenResponse struct {
+ AccessToken string `json:"accessToken"`
+ RefreshToken string `json:"refreshToken"`
+ ProfileArn string `json:"profileArn"`
+ ExpiresIn int `json:"expiresIn"`
+}
+
+// KiroOAuth handles the OAuth flow for Kiro authentication.
+type KiroOAuth struct {
+ httpClient *http.Client
+ cfg *config.Config
+}
+
+// NewKiroOAuth creates a new Kiro OAuth handler.
+func NewKiroOAuth(cfg *config.Config) *KiroOAuth {
+ client := &http.Client{Timeout: 30 * time.Second}
+ if cfg != nil {
+ client = util.SetProxy(&cfg.SDKConfig, client)
+ }
+ return &KiroOAuth{
+ httpClient: client,
+ cfg: cfg,
+ }
+}
+
+// generateCodeVerifier generates a random code verifier for PKCE.
+func generateCodeVerifier() (string, error) {
+ b := make([]byte, 32)
+ if _, err := rand.Read(b); err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// generateCodeChallenge generates the code challenge from verifier.
+func generateCodeChallenge(verifier string) string {
+ h := sha256.Sum256([]byte(verifier))
+ return base64.RawURLEncoding.EncodeToString(h[:])
+}
+
+// generateState generates a random state parameter.
+func generateState() (string, error) {
+ b := make([]byte, 16)
+ if _, err := rand.Read(b); err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// AuthResult contains the authorization code and state from callback.
+type AuthResult struct {
+ Code string
+ State string
+ Error string
+}
+
+// startCallbackServer starts a local HTTP server to receive the OAuth callback.
+func (o *KiroOAuth) startCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthResult, error) {
+ // Try to find an available port - use localhost like Kiro does
+ listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", defaultCallbackPort))
+ if err != nil {
+ // Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
+ log.Warnf("kiro oauth: default port %d is busy, falling back to dynamic port", defaultCallbackPort)
+ listener, err = net.Listen("tcp", "localhost:0")
+ if err != nil {
+ return "", nil, fmt.Errorf("failed to start callback server: %w", err)
+ }
+ }
+
+ port := listener.Addr().(*net.TCPAddr).Port
+ // Use http scheme for local callback server
+ redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
+ resultChan := make(chan AuthResult, 1)
+
+ server := &http.Server{
+ ReadHeaderTimeout: 10 * time.Second,
+ }
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
+ code := r.URL.Query().Get("code")
+ state := r.URL.Query().Get("state")
+ errParam := r.URL.Query().Get("error")
+
+ if errParam != "" {
+ w.Header().Set("Content-Type", "text/html")
+ w.WriteHeader(http.StatusBadRequest)
+ fmt.Fprintf(w, `Login Failed
%s
You can close this window.
`, html.EscapeString(errParam))
+ resultChan <- AuthResult{Error: errParam}
+ return
+ }
+
+ if state != expectedState {
+ w.Header().Set("Content-Type", "text/html")
+ w.WriteHeader(http.StatusBadRequest)
+ fmt.Fprint(w, `Login Failed
Invalid state parameter
You can close this window.
`)
+ resultChan <- AuthResult{Error: "state mismatch"}
+ return
+ }
+
+ w.Header().Set("Content-Type", "text/html")
+ fmt.Fprint(w, `Login Successful!
You can close this window and return to the terminal.
`)
+ resultChan <- AuthResult{Code: code, State: state}
+ })
+
+ server.Handler = mux
+
+ go func() {
+ if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
+ log.Debugf("callback server error: %v", err)
+ }
+ }()
+
+ go func() {
+ select {
+ case <-ctx.Done():
+ case <-time.After(authTimeout):
+ case <-resultChan:
+ }
+ _ = server.Shutdown(context.Background())
+ }()
+
+ return redirectURI, resultChan, nil
+}
+
+// LoginWithBuilderID performs OAuth login with AWS Builder ID using device code flow.
+func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
+ ssoClient := NewSSOOIDCClient(o.cfg)
+ return ssoClient.LoginWithBuilderID(ctx)
+}
+
+// LoginWithBuilderIDAuthCode performs OAuth login with AWS Builder ID using authorization code flow.
+// This provides a better UX than device code flow as it uses automatic browser callback.
+func (o *KiroOAuth) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) {
+ ssoClient := NewSSOOIDCClient(o.cfg)
+ return ssoClient.LoginWithBuilderIDAuthCode(ctx)
+}
+
+// exchangeCodeForToken exchanges the authorization code for tokens.
+func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) {
+ payload := map[string]string{
+ "code": code,
+ "code_verifier": codeVerifier,
+ "redirect_uri": redirectURI,
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+
+ tokenURL := kiroAuthEndpoint + "/oauth/token"
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
+
+ resp, err := o.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("token request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
+ }
+
+ var tokenResp KiroTokenResponse
+ if err := json.Unmarshal(respBody, &tokenResp); err != nil {
+ return nil, fmt.Errorf("failed to parse token response: %w", err)
+ }
+
+ // Validate ExpiresIn - use default 1 hour if invalid
+ expiresIn := tokenResp.ExpiresIn
+ if expiresIn <= 0 {
+ expiresIn = 3600
+ }
+ expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
+
+ return &KiroTokenData{
+ AccessToken: tokenResp.AccessToken,
+ RefreshToken: tokenResp.RefreshToken,
+ ProfileArn: tokenResp.ProfileArn,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ AuthMethod: "social",
+ Provider: "", // Caller should preserve original provider
+ Region: "us-east-1",
+ }, nil
+}
+
+// RefreshToken refreshes an expired access token.
+// Uses KiroIDE-style User-Agent to match official Kiro IDE behavior.
+func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
+ return o.RefreshTokenWithFingerprint(ctx, refreshToken, "")
+}
+
+// RefreshTokenWithFingerprint refreshes an expired access token with a specific fingerprint.
+// tokenKey is used to generate a consistent fingerprint for the token.
+func (o *KiroOAuth) RefreshTokenWithFingerprint(ctx context.Context, refreshToken, tokenKey string) (*KiroTokenData, error) {
+ payload := map[string]string{
+ "refreshToken": refreshToken,
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+
+ refreshURL := kiroAuthEndpoint + "/refreshToken"
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+
+ // Use KiroIDE-style User-Agent to match official Kiro IDE behavior
+ // This helps avoid 403 errors from server-side User-Agent validation
+ userAgent := buildKiroUserAgent(tokenKey)
+ req.Header.Set("User-Agent", userAgent)
+
+ resp, err := o.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("refresh request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
+ }
+
+ var tokenResp KiroTokenResponse
+ if err := json.Unmarshal(respBody, &tokenResp); err != nil {
+ return nil, fmt.Errorf("failed to parse token response: %w", err)
+ }
+
+ // Validate ExpiresIn - use default 1 hour if invalid
+ expiresIn := tokenResp.ExpiresIn
+ if expiresIn <= 0 {
+ expiresIn = 3600
+ }
+ expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
+
+ return &KiroTokenData{
+ AccessToken: tokenResp.AccessToken,
+ RefreshToken: tokenResp.RefreshToken,
+ ProfileArn: tokenResp.ProfileArn,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ AuthMethod: "social",
+ Provider: "", // Caller should preserve original provider
+ Region: "us-east-1",
+ }, nil
+}
+
+// buildKiroUserAgent builds a KiroIDE-style User-Agent string.
+// If tokenKey is provided, uses fingerprint manager for consistent fingerprint.
+// Otherwise generates a simple KiroIDE User-Agent.
+func buildKiroUserAgent(tokenKey string) string {
+ if tokenKey != "" {
+ fm := NewFingerprintManager()
+ fp := fm.GetFingerprint(tokenKey)
+ return fmt.Sprintf("KiroIDE-%s-%s", fp.KiroVersion, fp.KiroHash[:16])
+ }
+ // Default KiroIDE User-Agent matching kiro-openai-gateway format
+ return "KiroIDE-0.7.45-cli-proxy-api"
+}
+
+// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
+// This uses a custom protocol handler (kiro://) to receive the callback.
+func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
+ socialClient := NewSocialAuthClient(o.cfg)
+ return socialClient.LoginWithGoogle(ctx)
+}
+
+// LoginWithGitHub performs OAuth login with GitHub using Kiro's social auth.
+// This uses a custom protocol handler (kiro://) to receive the callback.
+func (o *KiroOAuth) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
+ socialClient := NewSocialAuthClient(o.cfg)
+ return socialClient.LoginWithGitHub(ctx)
+}
diff --git a/internal/auth/kiro/oauth_web.go b/internal/auth/kiro/oauth_web.go
new file mode 100644
index 00000000..88fba672
--- /dev/null
+++ b/internal/auth/kiro/oauth_web.go
@@ -0,0 +1,969 @@
+// Package kiro provides OAuth Web authentication for Kiro.
+package kiro
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "html/template"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ defaultSessionExpiry = 10 * time.Minute
+ pollIntervalSeconds = 5
+)
+
+type authSessionStatus string
+
+const (
+ statusPending authSessionStatus = "pending"
+ statusSuccess authSessionStatus = "success"
+ statusFailed authSessionStatus = "failed"
+)
+
+type webAuthSession struct {
+ stateID string
+ deviceCode string
+ userCode string
+ authURL string
+ verificationURI string
+ expiresIn int
+ interval int
+ status authSessionStatus
+ startedAt time.Time
+ completedAt time.Time
+ expiresAt time.Time
+ error string
+ tokenData *KiroTokenData
+ ssoClient *SSOOIDCClient
+ clientID string
+ clientSecret string
+ region string
+ cancelFunc context.CancelFunc
+ authMethod string // "google", "github", "builder-id", "idc"
+ startURL string // Used for IDC
+ codeVerifier string // Used for social auth PKCE
+ codeChallenge string // Used for social auth PKCE
+}
+
+type OAuthWebHandler struct {
+ cfg *config.Config
+ sessions map[string]*webAuthSession
+ mu sync.RWMutex
+ onTokenObtained func(*KiroTokenData)
+}
+
+func NewOAuthWebHandler(cfg *config.Config) *OAuthWebHandler {
+ return &OAuthWebHandler{
+ cfg: cfg,
+ sessions: make(map[string]*webAuthSession),
+ }
+}
+
+func (h *OAuthWebHandler) SetTokenCallback(callback func(*KiroTokenData)) {
+ h.onTokenObtained = callback
+}
+
+func (h *OAuthWebHandler) RegisterRoutes(router gin.IRouter) {
+ oauth := router.Group("/v0/oauth/kiro")
+ {
+ oauth.GET("", h.handleSelect)
+ oauth.GET("/start", h.handleStart)
+ oauth.GET("/callback", h.handleCallback)
+ oauth.GET("/social/callback", h.handleSocialCallback)
+ oauth.GET("/status", h.handleStatus)
+ oauth.POST("/import", h.handleImportToken)
+ oauth.POST("/refresh", h.handleManualRefresh)
+ }
+}
+
+func generateStateID() (string, error) {
+ b := make([]byte, 16)
+ if _, err := rand.Read(b); err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+func (h *OAuthWebHandler) handleSelect(c *gin.Context) {
+ h.renderSelectPage(c)
+}
+
+func (h *OAuthWebHandler) handleStart(c *gin.Context) {
+ method := c.Query("method")
+
+ if method == "" {
+ c.Redirect(http.StatusFound, "/v0/oauth/kiro")
+ return
+ }
+
+ switch method {
+ case "google", "github":
+ // Google/GitHub social login is not supported for third-party apps
+ // due to AWS Cognito redirect_uri restrictions
+ h.renderError(c, "Google/GitHub login is not available for third-party applications. Please use AWS Builder ID or import your token from Kiro IDE.")
+ case "builder-id":
+ h.startBuilderIDAuth(c)
+ case "idc":
+ h.startIDCAuth(c)
+ default:
+ h.renderError(c, fmt.Sprintf("Unknown authentication method: %s", method))
+ }
+}
+
+func (h *OAuthWebHandler) startSocialAuth(c *gin.Context, method string) {
+ stateID, err := generateStateID()
+ if err != nil {
+ h.renderError(c, "Failed to generate state parameter")
+ return
+ }
+
+ codeVerifier, codeChallenge, err := generatePKCE()
+ if err != nil {
+ h.renderError(c, "Failed to generate PKCE parameters")
+ return
+ }
+
+ socialClient := NewSocialAuthClient(h.cfg)
+
+ var provider string
+ if method == "google" {
+ provider = string(ProviderGoogle)
+ } else {
+ provider = string(ProviderGitHub)
+ }
+
+ redirectURI := h.getSocialCallbackURL(c)
+ authURL := socialClient.buildLoginURL(provider, redirectURI, codeChallenge, stateID)
+
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
+
+ session := &webAuthSession{
+ stateID: stateID,
+ authMethod: method,
+ authURL: authURL,
+ status: statusPending,
+ startedAt: time.Now(),
+ expiresIn: 600,
+ codeVerifier: codeVerifier,
+ codeChallenge: codeChallenge,
+ region: "us-east-1",
+ cancelFunc: cancel,
+ }
+
+ h.mu.Lock()
+ h.sessions[stateID] = session
+ h.mu.Unlock()
+
+ go func() {
+ <-ctx.Done()
+ h.mu.Lock()
+ if session.status == statusPending {
+ session.status = statusFailed
+ session.error = "Authentication timed out"
+ }
+ h.mu.Unlock()
+ }()
+
+ c.Redirect(http.StatusFound, authURL)
+}
+
+func (h *OAuthWebHandler) getSocialCallbackURL(c *gin.Context) string {
+ scheme := "http"
+ if c.Request.TLS != nil || c.GetHeader("X-Forwarded-Proto") == "https" {
+ scheme = "https"
+ }
+ return fmt.Sprintf("%s://%s/v0/oauth/kiro/social/callback", scheme, c.Request.Host)
+}
+
+func (h *OAuthWebHandler) startBuilderIDAuth(c *gin.Context) {
+ stateID, err := generateStateID()
+ if err != nil {
+ h.renderError(c, "Failed to generate state parameter")
+ return
+ }
+
+ region := defaultIDCRegion
+ startURL := builderIDStartURL
+
+ ssoClient := NewSSOOIDCClient(h.cfg)
+
+ regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
+ if err != nil {
+ log.Errorf("OAuth Web: failed to register client: %v", err)
+ h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
+ return
+ }
+
+ authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
+ c.Request.Context(),
+ regResp.ClientID,
+ regResp.ClientSecret,
+ startURL,
+ region,
+ )
+ if err != nil {
+ log.Errorf("OAuth Web: failed to start device authorization: %v", err)
+ h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
+
+ session := &webAuthSession{
+ stateID: stateID,
+ deviceCode: authResp.DeviceCode,
+ userCode: authResp.UserCode,
+ authURL: authResp.VerificationURIComplete,
+ verificationURI: authResp.VerificationURI,
+ expiresIn: authResp.ExpiresIn,
+ interval: authResp.Interval,
+ status: statusPending,
+ startedAt: time.Now(),
+ ssoClient: ssoClient,
+ clientID: regResp.ClientID,
+ clientSecret: regResp.ClientSecret,
+ region: region,
+ authMethod: "builder-id",
+ startURL: startURL,
+ cancelFunc: cancel,
+ }
+
+ h.mu.Lock()
+ h.sessions[stateID] = session
+ h.mu.Unlock()
+
+ go h.pollForToken(ctx, session)
+
+ h.renderStartPage(c, session)
+}
+
+func (h *OAuthWebHandler) startIDCAuth(c *gin.Context) {
+ startURL := c.Query("startUrl")
+ region := c.Query("region")
+
+ if startURL == "" {
+ h.renderError(c, "Missing startUrl parameter for IDC authentication")
+ return
+ }
+ if region == "" {
+ region = defaultIDCRegion
+ }
+
+ stateID, err := generateStateID()
+ if err != nil {
+ h.renderError(c, "Failed to generate state parameter")
+ return
+ }
+
+ ssoClient := NewSSOOIDCClient(h.cfg)
+
+ regResp, err := ssoClient.RegisterClientWithRegion(c.Request.Context(), region)
+ if err != nil {
+ log.Errorf("OAuth Web: failed to register client: %v", err)
+ h.renderError(c, fmt.Sprintf("Failed to register client: %v", err))
+ return
+ }
+
+ authResp, err := ssoClient.StartDeviceAuthorizationWithIDC(
+ c.Request.Context(),
+ regResp.ClientID,
+ regResp.ClientSecret,
+ startURL,
+ region,
+ )
+ if err != nil {
+ log.Errorf("OAuth Web: failed to start device authorization: %v", err)
+ h.renderError(c, fmt.Sprintf("Failed to start device authorization: %v", err))
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Duration(authResp.ExpiresIn)*time.Second)
+
+ session := &webAuthSession{
+ stateID: stateID,
+ deviceCode: authResp.DeviceCode,
+ userCode: authResp.UserCode,
+ authURL: authResp.VerificationURIComplete,
+ verificationURI: authResp.VerificationURI,
+ expiresIn: authResp.ExpiresIn,
+ interval: authResp.Interval,
+ status: statusPending,
+ startedAt: time.Now(),
+ ssoClient: ssoClient,
+ clientID: regResp.ClientID,
+ clientSecret: regResp.ClientSecret,
+ region: region,
+ authMethod: "idc",
+ startURL: startURL,
+ cancelFunc: cancel,
+ }
+
+ h.mu.Lock()
+ h.sessions[stateID] = session
+ h.mu.Unlock()
+
+ go h.pollForToken(ctx, session)
+
+ h.renderStartPage(c, session)
+}
+
+func (h *OAuthWebHandler) pollForToken(ctx context.Context, session *webAuthSession) {
+ defer session.cancelFunc()
+
+ interval := time.Duration(session.interval) * time.Second
+ if interval < time.Duration(pollIntervalSeconds)*time.Second {
+ interval = time.Duration(pollIntervalSeconds) * time.Second
+ }
+
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ h.mu.Lock()
+ if session.status == statusPending {
+ session.status = statusFailed
+ session.error = "Authentication timed out"
+ }
+ h.mu.Unlock()
+ return
+ case <-ticker.C:
+ tokenResp, err := h.ssoClient(session).CreateTokenWithRegion(
+ ctx,
+ session.clientID,
+ session.clientSecret,
+ session.deviceCode,
+ session.region,
+ )
+
+ if err != nil {
+ errStr := err.Error()
+ if errStr == ErrAuthorizationPending.Error() {
+ continue
+ }
+ if errStr == ErrSlowDown.Error() {
+ interval += 5 * time.Second
+ ticker.Reset(interval)
+ continue
+ }
+
+ h.mu.Lock()
+ session.status = statusFailed
+ session.error = errStr
+ session.completedAt = time.Now()
+ h.mu.Unlock()
+
+ log.Errorf("OAuth Web: token polling failed: %v", err)
+ return
+ }
+
+ expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
+ profileArn := session.ssoClient.fetchProfileArn(ctx, tokenResp.AccessToken)
+ email := FetchUserEmailWithFallback(ctx, h.cfg, tokenResp.AccessToken)
+
+ tokenData := &KiroTokenData{
+ AccessToken: tokenResp.AccessToken,
+ RefreshToken: tokenResp.RefreshToken,
+ ProfileArn: profileArn,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ AuthMethod: session.authMethod,
+ Provider: "AWS",
+ ClientID: session.clientID,
+ ClientSecret: session.clientSecret,
+ Email: email,
+ Region: session.region,
+ StartURL: session.startURL,
+ }
+
+ h.mu.Lock()
+ session.status = statusSuccess
+ session.completedAt = time.Now()
+ session.expiresAt = expiresAt
+ session.tokenData = tokenData
+ h.mu.Unlock()
+
+ if h.onTokenObtained != nil {
+ h.onTokenObtained(tokenData)
+ }
+
+ // Save token to file
+ h.saveTokenToFile(tokenData)
+
+ log.Infof("OAuth Web: authentication successful for %s", email)
+ return
+ }
+ }
+}
+
+// saveTokenToFile saves the token data to the auth directory
+func (h *OAuthWebHandler) saveTokenToFile(tokenData *KiroTokenData) {
+ // Get auth directory from config or use default
+ authDir := ""
+ if h.cfg != nil && h.cfg.AuthDir != "" {
+ var err error
+ authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
+ if err != nil {
+ log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
+ }
+ }
+
+ // Fall back to default location
+ if authDir == "" {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ log.Errorf("OAuth Web: failed to get home directory: %v", err)
+ return
+ }
+ authDir = filepath.Join(home, ".cli-proxy-api")
+ }
+
+ // Create directory if not exists
+ if err := os.MkdirAll(authDir, 0700); err != nil {
+ log.Errorf("OAuth Web: failed to create auth directory: %v", err)
+ return
+ }
+
+ // Generate filename using the unified function
+ fileName := GenerateTokenFileName(tokenData)
+
+ authFilePath := filepath.Join(authDir, fileName)
+
+ // Convert to storage format and save
+ storage := &KiroTokenStorage{
+ Type: "kiro",
+ AccessToken: tokenData.AccessToken,
+ RefreshToken: tokenData.RefreshToken,
+ ProfileArn: tokenData.ProfileArn,
+ ExpiresAt: tokenData.ExpiresAt,
+ AuthMethod: tokenData.AuthMethod,
+ Provider: tokenData.Provider,
+ LastRefresh: time.Now().Format(time.RFC3339),
+ ClientID: tokenData.ClientID,
+ ClientSecret: tokenData.ClientSecret,
+ Region: tokenData.Region,
+ StartURL: tokenData.StartURL,
+ Email: tokenData.Email,
+ }
+
+ if err := storage.SaveTokenToFile(authFilePath); err != nil {
+ log.Errorf("OAuth Web: failed to save token to file: %v", err)
+ return
+ }
+
+ log.Infof("OAuth Web: token saved to %s", authFilePath)
+}
+
+func (h *OAuthWebHandler) ssoClient(session *webAuthSession) *SSOOIDCClient {
+ return session.ssoClient
+}
+
+func (h *OAuthWebHandler) handleCallback(c *gin.Context) {
+ stateID := c.Query("state")
+ errParam := c.Query("error")
+
+ if errParam != "" {
+ h.renderError(c, errParam)
+ return
+ }
+
+ if stateID == "" {
+ h.renderError(c, "Missing state parameter")
+ return
+ }
+
+ h.mu.RLock()
+ session, exists := h.sessions[stateID]
+ h.mu.RUnlock()
+
+ if !exists {
+ h.renderError(c, "Invalid or expired session")
+ return
+ }
+
+ if session.status == statusSuccess {
+ h.renderSuccess(c, session)
+ } else if session.status == statusFailed {
+ h.renderError(c, session.error)
+ } else {
+ c.Redirect(http.StatusFound, "/v0/oauth/kiro/start")
+ }
+}
+
+func (h *OAuthWebHandler) handleSocialCallback(c *gin.Context) {
+ stateID := c.Query("state")
+ code := c.Query("code")
+ errParam := c.Query("error")
+
+ if errParam != "" {
+ h.renderError(c, errParam)
+ return
+ }
+
+ if stateID == "" {
+ h.renderError(c, "Missing state parameter")
+ return
+ }
+
+ if code == "" {
+ h.renderError(c, "Missing authorization code")
+ return
+ }
+
+ h.mu.RLock()
+ session, exists := h.sessions[stateID]
+ h.mu.RUnlock()
+
+ if !exists {
+ h.renderError(c, "Invalid or expired session")
+ return
+ }
+
+ if session.authMethod != "google" && session.authMethod != "github" {
+ h.renderError(c, "Invalid session type for social callback")
+ return
+ }
+
+ socialClient := NewSocialAuthClient(h.cfg)
+ redirectURI := h.getSocialCallbackURL(c)
+
+ tokenReq := &CreateTokenRequest{
+ Code: code,
+ CodeVerifier: session.codeVerifier,
+ RedirectURI: redirectURI,
+ }
+
+ tokenResp, err := socialClient.CreateToken(c.Request.Context(), tokenReq)
+ if err != nil {
+ log.Errorf("OAuth Web: social token exchange failed: %v", err)
+ h.mu.Lock()
+ session.status = statusFailed
+ session.error = fmt.Sprintf("Token exchange failed: %v", err)
+ session.completedAt = time.Now()
+ h.mu.Unlock()
+ h.renderError(c, session.error)
+ return
+ }
+
+ expiresIn := tokenResp.ExpiresIn
+ if expiresIn <= 0 {
+ expiresIn = 3600
+ }
+ expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
+
+ email := ExtractEmailFromJWT(tokenResp.AccessToken)
+
+ var provider string
+ if session.authMethod == "google" {
+ provider = string(ProviderGoogle)
+ } else {
+ provider = string(ProviderGitHub)
+ }
+
+ tokenData := &KiroTokenData{
+ AccessToken: tokenResp.AccessToken,
+ RefreshToken: tokenResp.RefreshToken,
+ ProfileArn: tokenResp.ProfileArn,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ AuthMethod: session.authMethod,
+ Provider: provider,
+ Email: email,
+ Region: "us-east-1",
+ }
+
+ h.mu.Lock()
+ session.status = statusSuccess
+ session.completedAt = time.Now()
+ session.expiresAt = expiresAt
+ session.tokenData = tokenData
+ h.mu.Unlock()
+
+ if session.cancelFunc != nil {
+ session.cancelFunc()
+ }
+
+ if h.onTokenObtained != nil {
+ h.onTokenObtained(tokenData)
+ }
+
+ // Save token to file
+ h.saveTokenToFile(tokenData)
+
+ log.Infof("OAuth Web: social authentication successful for %s via %s", email, provider)
+ h.renderSuccess(c, session)
+}
+
+func (h *OAuthWebHandler) handleStatus(c *gin.Context) {
+ stateID := c.Query("state")
+ if stateID == "" {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "missing state parameter"})
+ return
+ }
+
+ h.mu.RLock()
+ session, exists := h.sessions[stateID]
+ h.mu.RUnlock()
+
+ if !exists {
+ c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
+ return
+ }
+
+ response := gin.H{
+ "status": string(session.status),
+ }
+
+ switch session.status {
+ case statusPending:
+ elapsed := time.Since(session.startedAt).Seconds()
+ remaining := float64(session.expiresIn) - elapsed
+ if remaining < 0 {
+ remaining = 0
+ }
+ response["remaining_seconds"] = int(remaining)
+ case statusSuccess:
+ response["completed_at"] = session.completedAt.Format(time.RFC3339)
+ response["expires_at"] = session.expiresAt.Format(time.RFC3339)
+ case statusFailed:
+ response["error"] = session.error
+ response["failed_at"] = session.completedAt.Format(time.RFC3339)
+ }
+
+ c.JSON(http.StatusOK, response)
+}
+
+func (h *OAuthWebHandler) renderStartPage(c *gin.Context, session *webAuthSession) {
+ tmpl, err := template.New("start").Parse(oauthWebStartPageHTML)
+ if err != nil {
+ log.Errorf("OAuth Web: failed to parse template: %v", err)
+ c.String(http.StatusInternalServerError, "Template error")
+ return
+ }
+
+ data := map[string]interface{}{
+ "AuthURL": session.authURL,
+ "UserCode": session.userCode,
+ "ExpiresIn": session.expiresIn,
+ "StateID": session.stateID,
+ }
+
+ c.Header("Content-Type", "text/html; charset=utf-8")
+ if err := tmpl.Execute(c.Writer, data); err != nil {
+ log.Errorf("OAuth Web: failed to render template: %v", err)
+ }
+}
+
+func (h *OAuthWebHandler) renderSelectPage(c *gin.Context) {
+ tmpl, err := template.New("select").Parse(oauthWebSelectPageHTML)
+ if err != nil {
+ log.Errorf("OAuth Web: failed to parse select template: %v", err)
+ c.String(http.StatusInternalServerError, "Template error")
+ return
+ }
+
+ c.Header("Content-Type", "text/html; charset=utf-8")
+ if err := tmpl.Execute(c.Writer, nil); err != nil {
+ log.Errorf("OAuth Web: failed to render select template: %v", err)
+ }
+}
+
+func (h *OAuthWebHandler) renderError(c *gin.Context, errMsg string) {
+ tmpl, err := template.New("error").Parse(oauthWebErrorPageHTML)
+ if err != nil {
+ log.Errorf("OAuth Web: failed to parse error template: %v", err)
+ c.String(http.StatusInternalServerError, "Template error")
+ return
+ }
+
+ data := map[string]interface{}{
+ "Error": errMsg,
+ }
+
+ c.Header("Content-Type", "text/html; charset=utf-8")
+ c.Status(http.StatusBadRequest)
+ if err := tmpl.Execute(c.Writer, data); err != nil {
+ log.Errorf("OAuth Web: failed to render error template: %v", err)
+ }
+}
+
+func (h *OAuthWebHandler) renderSuccess(c *gin.Context, session *webAuthSession) {
+ tmpl, err := template.New("success").Parse(oauthWebSuccessPageHTML)
+ if err != nil {
+ log.Errorf("OAuth Web: failed to parse success template: %v", err)
+ c.String(http.StatusInternalServerError, "Template error")
+ return
+ }
+
+ data := map[string]interface{}{
+ "ExpiresAt": session.expiresAt.Format(time.RFC3339),
+ }
+
+ c.Header("Content-Type", "text/html; charset=utf-8")
+ if err := tmpl.Execute(c.Writer, data); err != nil {
+ log.Errorf("OAuth Web: failed to render success template: %v", err)
+ }
+}
+
+func (h *OAuthWebHandler) CleanupExpiredSessions() {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ now := time.Now()
+ for id, session := range h.sessions {
+ if session.status != statusPending && now.Sub(session.completedAt) > 30*time.Minute {
+ delete(h.sessions, id)
+ } else if session.status == statusPending && now.Sub(session.startedAt) > defaultSessionExpiry {
+ session.cancelFunc()
+ delete(h.sessions, id)
+ }
+ }
+}
+
+func (h *OAuthWebHandler) GetSession(stateID string) (*webAuthSession, bool) {
+ h.mu.RLock()
+ defer h.mu.RUnlock()
+ session, exists := h.sessions[stateID]
+ return session, exists
+}
+
+// ImportTokenRequest represents the request body for token import
+type ImportTokenRequest struct {
+ RefreshToken string `json:"refreshToken"`
+}
+
+// handleImportToken handles manual refresh token import from Kiro IDE
+func (h *OAuthWebHandler) handleImportToken(c *gin.Context) {
+ var req ImportTokenRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "success": false,
+ "error": "Invalid request body",
+ })
+ return
+ }
+
+ refreshToken := strings.TrimSpace(req.RefreshToken)
+ if refreshToken == "" {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "success": false,
+ "error": "Refresh token is required",
+ })
+ return
+ }
+
+ // Validate token format
+ if !strings.HasPrefix(refreshToken, "aorAAAAAG") {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "success": false,
+ "error": "Invalid token format. Token should start with aorAAAAAG...",
+ })
+ return
+ }
+
+ // Create social auth client to refresh and validate the token
+ socialClient := NewSocialAuthClient(h.cfg)
+
+ // Refresh the token to validate it and get access token
+ tokenData, err := socialClient.RefreshSocialToken(c.Request.Context(), refreshToken)
+ if err != nil {
+ log.Errorf("OAuth Web: token refresh failed during import: %v", err)
+ c.JSON(http.StatusBadRequest, gin.H{
+ "success": false,
+ "error": fmt.Sprintf("Token validation failed: %v", err),
+ })
+ return
+ }
+
+ // Set the original refresh token (the refreshed one might be empty)
+ if tokenData.RefreshToken == "" {
+ tokenData.RefreshToken = refreshToken
+ }
+ tokenData.AuthMethod = "social"
+ tokenData.Provider = "imported"
+
+ // Notify callback if set
+ if h.onTokenObtained != nil {
+ h.onTokenObtained(tokenData)
+ }
+
+ // Save token to file
+ h.saveTokenToFile(tokenData)
+
+ // Generate filename for response using the unified function
+ fileName := GenerateTokenFileName(tokenData)
+
+ log.Infof("OAuth Web: token imported successfully")
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "Token imported successfully",
+ "fileName": fileName,
+ })
+}
+
+// handleManualRefresh handles manual token refresh requests from the web UI.
+// This allows users to trigger a token refresh when needed, without waiting
+// for the automatic 30-second check and 20-minute-before-expiry refresh cycle.
+// Uses the same refresh logic as kiro_executor.Refresh for consistency.
+func (h *OAuthWebHandler) handleManualRefresh(c *gin.Context) {
+ authDir := ""
+ if h.cfg != nil && h.cfg.AuthDir != "" {
+ var err error
+ authDir, err = util.ResolveAuthDir(h.cfg.AuthDir)
+ if err != nil {
+ log.Errorf("OAuth Web: failed to resolve auth directory: %v", err)
+ }
+ }
+
+ if authDir == "" {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "success": false,
+ "error": "Failed to get home directory",
+ })
+ return
+ }
+ authDir = filepath.Join(home, ".cli-proxy-api")
+ }
+
+ // Find all kiro token files in the auth directory
+ files, err := os.ReadDir(authDir)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{
+ "success": false,
+ "error": fmt.Sprintf("Failed to read auth directory: %v", err),
+ })
+ return
+ }
+
+ var refreshedCount int
+ var errors []string
+
+ for _, file := range files {
+ if file.IsDir() {
+ continue
+ }
+ name := file.Name()
+ if !strings.HasPrefix(name, "kiro-") || !strings.HasSuffix(name, ".json") {
+ continue
+ }
+
+ filePath := filepath.Join(authDir, name)
+ data, err := os.ReadFile(filePath)
+ if err != nil {
+ errors = append(errors, fmt.Sprintf("%s: read error - %v", name, err))
+ continue
+ }
+
+ var storage KiroTokenStorage
+ if err := json.Unmarshal(data, &storage); err != nil {
+ errors = append(errors, fmt.Sprintf("%s: parse error - %v", name, err))
+ continue
+ }
+
+ if storage.RefreshToken == "" {
+ errors = append(errors, fmt.Sprintf("%s: no refresh token", name))
+ continue
+ }
+
+ // Refresh token using the same logic as kiro_executor.Refresh
+ tokenData, err := h.refreshTokenData(c.Request.Context(), &storage)
+ if err != nil {
+ errors = append(errors, fmt.Sprintf("%s: refresh failed - %v", name, err))
+ continue
+ }
+
+ // Update storage with new token data
+ storage.AccessToken = tokenData.AccessToken
+ if tokenData.RefreshToken != "" {
+ storage.RefreshToken = tokenData.RefreshToken
+ }
+ storage.ExpiresAt = tokenData.ExpiresAt
+ storage.LastRefresh = time.Now().Format(time.RFC3339)
+ if tokenData.ProfileArn != "" {
+ storage.ProfileArn = tokenData.ProfileArn
+ }
+
+ // Write updated token back to file
+ updatedData, err := json.MarshalIndent(storage, "", " ")
+ if err != nil {
+ errors = append(errors, fmt.Sprintf("%s: marshal error - %v", name, err))
+ continue
+ }
+
+ tmpFile := filePath + ".tmp"
+ if err := os.WriteFile(tmpFile, updatedData, 0600); err != nil {
+ errors = append(errors, fmt.Sprintf("%s: write error - %v", name, err))
+ continue
+ }
+ if err := os.Rename(tmpFile, filePath); err != nil {
+ errors = append(errors, fmt.Sprintf("%s: rename error - %v", name, err))
+ continue
+ }
+
+ log.Infof("OAuth Web: manually refreshed token in %s, expires at %s", name, tokenData.ExpiresAt)
+ refreshedCount++
+
+ // Notify callback if set
+ if h.onTokenObtained != nil {
+ h.onTokenObtained(tokenData)
+ }
+ }
+
+ if refreshedCount == 0 && len(errors) > 0 {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "success": false,
+ "error": fmt.Sprintf("All refresh attempts failed: %v", errors),
+ })
+ return
+ }
+
+ response := gin.H{
+ "success": true,
+ "message": fmt.Sprintf("Refreshed %d token(s)", refreshedCount),
+ "refreshedCount": refreshedCount,
+ }
+ if len(errors) > 0 {
+ response["warnings"] = errors
+ }
+
+ c.JSON(http.StatusOK, response)
+}
+
+// refreshTokenData refreshes a token using the appropriate method based on auth type.
+// This mirrors the logic in kiro_executor.Refresh for consistency.
+func (h *OAuthWebHandler) refreshTokenData(ctx context.Context, storage *KiroTokenStorage) (*KiroTokenData, error) {
+ ssoClient := NewSSOOIDCClient(h.cfg)
+
+ switch {
+ case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "idc" && storage.Region != "":
+ // IDC refresh with region-specific endpoint
+ log.Debugf("OAuth Web: using SSO OIDC refresh for IDC (region=%s)", storage.Region)
+ return ssoClient.RefreshTokenWithRegion(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken, storage.Region, storage.StartURL)
+
+ case storage.ClientID != "" && storage.ClientSecret != "" && storage.AuthMethod == "builder-id":
+ // Builder ID refresh with default endpoint
+ log.Debugf("OAuth Web: using SSO OIDC refresh for AWS Builder ID")
+ return ssoClient.RefreshToken(ctx, storage.ClientID, storage.ClientSecret, storage.RefreshToken)
+
+ default:
+ // Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
+ log.Debugf("OAuth Web: using Kiro OAuth refresh endpoint")
+ oauth := NewKiroOAuth(h.cfg)
+ return oauth.RefreshToken(ctx, storage.RefreshToken)
+ }
+}
diff --git a/internal/auth/kiro/oauth_web_templates.go b/internal/auth/kiro/oauth_web_templates.go
new file mode 100644
index 00000000..228677a5
--- /dev/null
+++ b/internal/auth/kiro/oauth_web_templates.go
@@ -0,0 +1,779 @@
+// Package kiro provides OAuth Web authentication templates.
+package kiro
+
+const (
+ oauthWebStartPageHTML = `
+
+
+
+
+ AWS SSO Authentication
+
+
+
+
+
🔐 AWS SSO Authentication
+
Follow the steps below to complete authentication
+
+
+
+
+
+ 2
+ Enter the verification code below
+
+
+
Verification Code
+
{{.UserCode}}
+
+
+
+
+
+ 3
+ Complete AWS SSO login
+
+
+ Use your AWS SSO account to login and authorize
+
+
+
+
+
+
{{.ExpiresIn}}s
+
+ Waiting for authorization...
+
+
+
+
+ 💡 Tip: The authorization page will open in a new tab. This page will automatically update once authorization is complete.
+
+
+
+
+
+`
+
+ oauthWebErrorPageHTML = `
+
+
+
+
+ Authentication Failed
+
+
+
+
+
❌ Authentication Failed
+
+
🔄 Retry
+
+
+`
+
+ oauthWebSuccessPageHTML = `
+
+
+
+
+ Authentication Successful
+
+
+
+
+
✅
+
Authentication Successful!
+
+
You can close this window.
+
+
Token expires: {{.ExpiresAt}}
+
+
+`
+
+ oauthWebSelectPageHTML = `
+
+
+
+
+ Select Authentication Method
+
+
+
+
+
🔐 Select Authentication Method
+
Choose how you want to authenticate with Kiro
+
+
+
+
+
+
+
+
+ ⚠️ Note: Google and GitHub login are not available for third-party applications due to AWS Cognito restrictions. Please use AWS Builder ID or import your token from Kiro IDE.
+
+
+
+ 💡 How to get RefreshToken:
+ 1. Open Kiro IDE and login with Google/GitHub
+ 2. Find the token file: ~/.kiro/kiro-auth-token.json
+ 3. Copy the refreshToken value and paste it above
+
+
+
+
+
+`
+)
diff --git a/internal/auth/kiro/protocol_handler.go b/internal/auth/kiro/protocol_handler.go
new file mode 100644
index 00000000..d900ee33
--- /dev/null
+++ b/internal/auth/kiro/protocol_handler.go
@@ -0,0 +1,725 @@
+// Package kiro provides custom protocol handler registration for Kiro OAuth.
+// This enables the CLI to intercept kiro:// URIs for social authentication (Google/GitHub).
+package kiro
+
+import (
+ "context"
+ "fmt"
+ "html"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ // KiroProtocol is the custom URI scheme used by Kiro
+ KiroProtocol = "kiro"
+
+ // KiroAuthority is the URI authority for authentication callbacks
+ KiroAuthority = "kiro.kiroAgent"
+
+ // KiroAuthPath is the path for successful authentication
+ KiroAuthPath = "/authenticate-success"
+
+ // KiroRedirectURI is the full redirect URI for social auth
+ KiroRedirectURI = "kiro://kiro.kiroAgent/authenticate-success"
+
+ // DefaultHandlerPort is the default port for the local callback server
+ DefaultHandlerPort = 19876
+
+ // HandlerTimeout is how long to wait for the OAuth callback
+ HandlerTimeout = 10 * time.Minute
+)
+
+// ProtocolHandler manages the custom kiro:// protocol handler for OAuth callbacks.
+type ProtocolHandler struct {
+ port int
+ server *http.Server
+ listener net.Listener
+ resultChan chan *AuthCallback
+ stopChan chan struct{}
+ mu sync.Mutex
+ running bool
+}
+
+// AuthCallback contains the OAuth callback parameters.
+type AuthCallback struct {
+ Code string
+ State string
+ Error string
+}
+
+// NewProtocolHandler creates a new protocol handler.
+func NewProtocolHandler() *ProtocolHandler {
+ return &ProtocolHandler{
+ port: DefaultHandlerPort,
+ resultChan: make(chan *AuthCallback, 1),
+ stopChan: make(chan struct{}),
+ }
+}
+
+// Start starts the local callback server that receives redirects from the protocol handler.
+func (h *ProtocolHandler) Start(ctx context.Context) (int, error) {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ if h.running {
+ return h.port, nil
+ }
+
+ // Drain any stale results from previous runs
+ select {
+ case <-h.resultChan:
+ default:
+ }
+
+ // Reset stopChan for reuse - close old channel first to unblock any waiting goroutines
+ if h.stopChan != nil {
+ select {
+ case <-h.stopChan:
+ // Already closed
+ default:
+ close(h.stopChan)
+ }
+ }
+ h.stopChan = make(chan struct{})
+
+ // Try ports in known range (must match handler script port range)
+ var listener net.Listener
+ var err error
+ portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4}
+
+ for _, port := range portRange {
+ listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
+ if err == nil {
+ break
+ }
+ log.Debugf("kiro protocol handler: port %d busy, trying next", port)
+ }
+
+ if listener == nil {
+ return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4)
+ }
+
+ h.listener = listener
+ h.port = listener.Addr().(*net.TCPAddr).Port
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/oauth/callback", h.handleCallback)
+
+ h.server = &http.Server{
+ Handler: mux,
+ ReadHeaderTimeout: 10 * time.Second,
+ }
+
+ go func() {
+ if err := h.server.Serve(listener); err != nil && err != http.ErrServerClosed {
+ log.Debugf("kiro protocol handler server error: %v", err)
+ }
+ }()
+
+ h.running = true
+ log.Debugf("kiro protocol handler started on port %d", h.port)
+
+ // Auto-shutdown after context done, timeout, or explicit stop
+ // Capture references to prevent race with new Start() calls
+ currentStopChan := h.stopChan
+ currentServer := h.server
+ currentListener := h.listener
+ go func() {
+ select {
+ case <-ctx.Done():
+ case <-time.After(HandlerTimeout):
+ case <-currentStopChan:
+ return // Already stopped, exit goroutine
+ }
+ // Only stop if this is still the current server/listener instance
+ h.mu.Lock()
+ if h.server == currentServer && h.listener == currentListener {
+ h.mu.Unlock()
+ h.Stop()
+ } else {
+ h.mu.Unlock()
+ }
+ }()
+
+ return h.port, nil
+}
+
+// Stop stops the callback server.
+func (h *ProtocolHandler) Stop() {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ if !h.running {
+ return
+ }
+
+ // Signal the auto-shutdown goroutine to exit.
+ // This select pattern is safe because stopChan is only modified while holding h.mu,
+ // and we hold the lock here. The select prevents panic from double-close.
+ select {
+ case <-h.stopChan:
+ // Already closed
+ default:
+ close(h.stopChan)
+ }
+
+ if h.server != nil {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ _ = h.server.Shutdown(ctx)
+ }
+
+ h.running = false
+ log.Debug("kiro protocol handler stopped")
+}
+
+// WaitForCallback waits for the OAuth callback and returns the result.
+func (h *ProtocolHandler) WaitForCallback(ctx context.Context) (*AuthCallback, error) {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(HandlerTimeout):
+ return nil, fmt.Errorf("timeout waiting for OAuth callback")
+ case result := <-h.resultChan:
+ return result, nil
+ }
+}
+
+// GetPort returns the port the handler is listening on.
+func (h *ProtocolHandler) GetPort() int {
+ return h.port
+}
+
+// handleCallback processes the OAuth callback from the protocol handler script.
+func (h *ProtocolHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
+ code := r.URL.Query().Get("code")
+ state := r.URL.Query().Get("state")
+ errParam := r.URL.Query().Get("error")
+
+ result := &AuthCallback{
+ Code: code,
+ State: state,
+ Error: errParam,
+ }
+
+ // Send result
+ select {
+ case h.resultChan <- result:
+ default:
+ // Channel full, ignore duplicate callbacks
+ }
+
+ // Send success response
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ if errParam != "" {
+ w.WriteHeader(http.StatusBadRequest)
+ fmt.Fprintf(w, `
+
+Login Failed
+
+Login Failed
+Error: %s
+You can close this window.
+
+`, html.EscapeString(errParam))
+ } else {
+ fmt.Fprint(w, `
+
+Login Successful
+
+Login Successful!
+You can close this window and return to the terminal.
+
+
+`)
+ }
+}
+
+// IsProtocolHandlerInstalled checks if the kiro:// protocol handler is installed.
+func IsProtocolHandlerInstalled() bool {
+ switch runtime.GOOS {
+ case "linux":
+ return isLinuxHandlerInstalled()
+ case "windows":
+ return isWindowsHandlerInstalled()
+ case "darwin":
+ return isDarwinHandlerInstalled()
+ default:
+ return false
+ }
+}
+
+// InstallProtocolHandler installs the kiro:// protocol handler for the current platform.
+func InstallProtocolHandler(handlerPort int) error {
+ switch runtime.GOOS {
+ case "linux":
+ return installLinuxHandler(handlerPort)
+ case "windows":
+ return installWindowsHandler(handlerPort)
+ case "darwin":
+ return installDarwinHandler(handlerPort)
+ default:
+ return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
+ }
+}
+
+// UninstallProtocolHandler removes the kiro:// protocol handler.
+func UninstallProtocolHandler() error {
+ switch runtime.GOOS {
+ case "linux":
+ return uninstallLinuxHandler()
+ case "windows":
+ return uninstallWindowsHandler()
+ case "darwin":
+ return uninstallDarwinHandler()
+ default:
+ return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
+ }
+}
+
+// --- Linux Implementation ---
+
+func getLinuxDesktopPath() string {
+ homeDir, _ := os.UserHomeDir()
+ return filepath.Join(homeDir, ".local", "share", "applications", "kiro-oauth-handler.desktop")
+}
+
+func getLinuxHandlerScriptPath() string {
+ homeDir, _ := os.UserHomeDir()
+ return filepath.Join(homeDir, ".local", "bin", "kiro-oauth-handler")
+}
+
+func isLinuxHandlerInstalled() bool {
+ desktopPath := getLinuxDesktopPath()
+ _, err := os.Stat(desktopPath)
+ return err == nil
+}
+
+func installLinuxHandler(handlerPort int) error {
+ // Create directories
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ return err
+ }
+
+ binDir := filepath.Join(homeDir, ".local", "bin")
+ appDir := filepath.Join(homeDir, ".local", "share", "applications")
+
+ if err := os.MkdirAll(binDir, 0755); err != nil {
+ return fmt.Errorf("failed to create bin directory: %w", err)
+ }
+ if err := os.MkdirAll(appDir, 0755); err != nil {
+ return fmt.Errorf("failed to create applications directory: %w", err)
+ }
+
+ // Create handler script - tries multiple ports to handle dynamic port allocation
+ scriptPath := getLinuxHandlerScriptPath()
+ scriptContent := fmt.Sprintf(`#!/bin/bash
+# Kiro OAuth Protocol Handler
+# Handles kiro:// URIs - tries CLI first, then forwards to Kiro IDE
+
+URL="$1"
+
+# Check curl availability
+if ! command -v curl &> /dev/null; then
+ echo "Error: curl is required for Kiro OAuth handler" >&2
+ exit 1
+fi
+
+# Extract code and state from URL
+[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
+[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
+[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"
+
+# Try CLI proxy on multiple possible ports (default + dynamic range)
+CLI_OK=0
+for PORT in %d %d %d %d %d; do
+ if [ -n "$ERROR" ]; then
+ curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && CLI_OK=1 && break
+ elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
+ curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && CLI_OK=1 && break
+ fi
+done
+
+# If CLI not available, forward to Kiro IDE
+if [ $CLI_OK -eq 0 ] && [ -x "/usr/share/kiro/kiro" ]; then
+ /usr/share/kiro/kiro --open-url "$URL" &
+fi
+`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
+
+ if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
+ return fmt.Errorf("failed to write handler script: %w", err)
+ }
+
+ // Create .desktop file
+ desktopPath := getLinuxDesktopPath()
+ desktopContent := fmt.Sprintf(`[Desktop Entry]
+Name=Kiro OAuth Handler
+Comment=Handle kiro:// protocol for CLI Proxy API authentication
+Exec=%s %%u
+Type=Application
+Terminal=false
+NoDisplay=true
+MimeType=x-scheme-handler/kiro;
+Categories=Utility;
+`, scriptPath)
+
+ if err := os.WriteFile(desktopPath, []byte(desktopContent), 0644); err != nil {
+ return fmt.Errorf("failed to write desktop file: %w", err)
+ }
+
+ // Register handler with xdg-mime
+ cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
+ if err := cmd.Run(); err != nil {
+ log.Warnf("xdg-mime registration failed (may need manual setup): %v", err)
+ }
+
+ // Update desktop database
+ cmd = exec.Command("update-desktop-database", appDir)
+ _ = cmd.Run() // Ignore errors, not critical
+
+ log.Info("Kiro protocol handler installed for Linux")
+ return nil
+}
+
+func uninstallLinuxHandler() error {
+ desktopPath := getLinuxDesktopPath()
+ scriptPath := getLinuxHandlerScriptPath()
+
+ if err := os.Remove(desktopPath); err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("failed to remove desktop file: %w", err)
+ }
+ if err := os.Remove(scriptPath); err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("failed to remove handler script: %w", err)
+ }
+
+ log.Info("Kiro protocol handler uninstalled")
+ return nil
+}
+
+// --- Windows Implementation ---
+
+func isWindowsHandlerInstalled() bool {
+ // Check registry key existence
+ cmd := exec.Command("reg", "query", `HKCU\Software\Classes\kiro`, "/ve")
+ return cmd.Run() == nil
+}
+
+func installWindowsHandler(handlerPort int) error {
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ return err
+ }
+
+ // Create handler script (PowerShell)
+ scriptDir := filepath.Join(homeDir, ".cliproxyapi")
+ if err := os.MkdirAll(scriptDir, 0755); err != nil {
+ return fmt.Errorf("failed to create script directory: %w", err)
+ }
+
+ scriptPath := filepath.Join(scriptDir, "kiro-oauth-handler.ps1")
+ scriptContent := fmt.Sprintf(`# Kiro OAuth Protocol Handler for Windows
+param([string]$url)
+
+# Load required assembly for HttpUtility
+Add-Type -AssemblyName System.Web
+
+# Parse URL parameters
+$uri = [System.Uri]$url
+$query = [System.Web.HttpUtility]::ParseQueryString($uri.Query)
+$code = $query["code"]
+$state = $query["state"]
+$errorParam = $query["error"]
+
+# Try multiple ports (default + dynamic range)
+$ports = @(%d, %d, %d, %d, %d)
+$success = $false
+
+foreach ($port in $ports) {
+ if ($success) { break }
+ $callbackUrl = "http://127.0.0.1:$port/oauth/callback"
+ try {
+ if ($errorParam) {
+ $fullUrl = $callbackUrl + "?error=" + $errorParam
+ Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
+ $success = $true
+ } elseif ($code -and $state) {
+ $fullUrl = $callbackUrl + "?code=" + $code + "&state=" + $state
+ Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
+ $success = $true
+ }
+ } catch {
+ # Try next port
+ }
+}
+`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
+
+ if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil {
+ return fmt.Errorf("failed to write handler script: %w", err)
+ }
+
+ // Create batch wrapper
+ batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat")
+ batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" %%1\n", scriptPath)
+
+ if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil {
+ return fmt.Errorf("failed to write batch wrapper: %w", err)
+ }
+
+ // Register in Windows registry
+ commands := [][]string{
+ {"reg", "add", `HKCU\Software\Classes\kiro`, "/ve", "/d", "URL:Kiro Protocol", "/f"},
+ {"reg", "add", `HKCU\Software\Classes\kiro`, "/v", "URL Protocol", "/d", "", "/f"},
+ {"reg", "add", `HKCU\Software\Classes\kiro\shell`, "/f"},
+ {"reg", "add", `HKCU\Software\Classes\kiro\shell\open`, "/f"},
+ {"reg", "add", `HKCU\Software\Classes\kiro\shell\open\command`, "/ve", "/d", fmt.Sprintf("\"%s\" \"%%1\"", batchPath), "/f"},
+ }
+
+ for _, args := range commands {
+ cmd := exec.Command(args[0], args[1:]...)
+ if err := cmd.Run(); err != nil {
+ return fmt.Errorf("failed to run registry command: %w", err)
+ }
+ }
+
+ log.Info("Kiro protocol handler installed for Windows")
+ return nil
+}
+
+func uninstallWindowsHandler() error {
+ // Remove registry keys
+ cmd := exec.Command("reg", "delete", `HKCU\Software\Classes\kiro`, "/f")
+ if err := cmd.Run(); err != nil {
+ log.Warnf("failed to remove registry key: %v", err)
+ }
+
+ // Remove scripts
+ homeDir, _ := os.UserHomeDir()
+ scriptDir := filepath.Join(homeDir, ".cliproxyapi")
+ _ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.ps1"))
+ _ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.bat"))
+
+ log.Info("Kiro protocol handler uninstalled")
+ return nil
+}
+
+// --- macOS Implementation ---
+
+func getDarwinAppPath() string {
+ homeDir, _ := os.UserHomeDir()
+ return filepath.Join(homeDir, "Applications", "KiroOAuthHandler.app")
+}
+
+func isDarwinHandlerInstalled() bool {
+ appPath := getDarwinAppPath()
+ _, err := os.Stat(appPath)
+ return err == nil
+}
+
+func installDarwinHandler(handlerPort int) error {
+ // Create app bundle structure
+ appPath := getDarwinAppPath()
+ contentsPath := filepath.Join(appPath, "Contents")
+ macOSPath := filepath.Join(contentsPath, "MacOS")
+
+ if err := os.MkdirAll(macOSPath, 0755); err != nil {
+ return fmt.Errorf("failed to create app bundle: %w", err)
+ }
+
+ // Create Info.plist
+ plistPath := filepath.Join(contentsPath, "Info.plist")
+ plistContent := `
+
+
+
+ CFBundleIdentifier
+ com.cliproxyapi.kiro-oauth-handler
+ CFBundleName
+ KiroOAuthHandler
+ CFBundleExecutable
+ kiro-oauth-handler
+ CFBundleVersion
+ 1.0
+ CFBundleURLTypes
+
+
+ CFBundleURLName
+ Kiro Protocol
+ CFBundleURLSchemes
+
+ kiro
+
+
+
+ LSBackgroundOnly
+
+
+`
+
+ if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil {
+ return fmt.Errorf("failed to write Info.plist: %w", err)
+ }
+
+ // Create executable script - tries multiple ports to handle dynamic port allocation
+ execPath := filepath.Join(macOSPath, "kiro-oauth-handler")
+ execContent := fmt.Sprintf(`#!/bin/bash
+# Kiro OAuth Protocol Handler for macOS
+
+URL="$1"
+
+# Check curl availability (should always exist on macOS)
+if [ ! -x /usr/bin/curl ]; then
+ echo "Error: curl is required for Kiro OAuth handler" >&2
+ exit 1
+fi
+
+# Extract code and state from URL
+[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
+[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
+[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"
+
+# Try multiple ports (default + dynamic range)
+for PORT in %d %d %d %d %d; do
+ if [ -n "$ERROR" ]; then
+ /usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && exit 0
+ elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
+ /usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && exit 0
+ fi
+done
+`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
+
+ if err := os.WriteFile(execPath, []byte(execContent), 0755); err != nil {
+ return fmt.Errorf("failed to write executable: %w", err)
+ }
+
+ // Register the app with Launch Services
+ cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
+ "-f", appPath)
+ if err := cmd.Run(); err != nil {
+ log.Warnf("lsregister failed (handler may still work): %v", err)
+ }
+
+ log.Info("Kiro protocol handler installed for macOS")
+ return nil
+}
+
+func uninstallDarwinHandler() error {
+ appPath := getDarwinAppPath()
+
+ // Unregister from Launch Services
+ cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
+ "-u", appPath)
+ _ = cmd.Run()
+
+ // Remove app bundle
+ if err := os.RemoveAll(appPath); err != nil && !os.IsNotExist(err) {
+ return fmt.Errorf("failed to remove app bundle: %w", err)
+ }
+
+ log.Info("Kiro protocol handler uninstalled")
+ return nil
+}
+
+// ParseKiroURI parses a kiro:// URI and extracts the callback parameters.
+func ParseKiroURI(rawURI string) (*AuthCallback, error) {
+ u, err := url.Parse(rawURI)
+ if err != nil {
+ return nil, fmt.Errorf("invalid URI: %w", err)
+ }
+
+ if u.Scheme != KiroProtocol {
+ return nil, fmt.Errorf("invalid scheme: expected %s, got %s", KiroProtocol, u.Scheme)
+ }
+
+ if u.Host != KiroAuthority {
+ return nil, fmt.Errorf("invalid authority: expected %s, got %s", KiroAuthority, u.Host)
+ }
+
+ query := u.Query()
+ return &AuthCallback{
+ Code: query.Get("code"),
+ State: query.Get("state"),
+ Error: query.Get("error"),
+ }, nil
+}
+
+// GetHandlerInstructions returns platform-specific instructions for manual handler setup.
+func GetHandlerInstructions() string {
+ switch runtime.GOOS {
+ case "linux":
+ return `To manually set up the Kiro protocol handler on Linux:
+
+1. Create ~/.local/share/applications/kiro-oauth-handler.desktop:
+ [Desktop Entry]
+ Name=Kiro OAuth Handler
+ Exec=~/.local/bin/kiro-oauth-handler %u
+ Type=Application
+ Terminal=false
+ MimeType=x-scheme-handler/kiro;
+
+2. Create ~/.local/bin/kiro-oauth-handler (make it executable):
+ #!/bin/bash
+ URL="$1"
+ # ... (see generated script for full content)
+
+3. Run: xdg-mime default kiro-oauth-handler.desktop x-scheme-handler/kiro`
+
+ case "windows":
+ return `To manually set up the Kiro protocol handler on Windows:
+
+1. Open Registry Editor (regedit.exe)
+2. Create key: HKEY_CURRENT_USER\Software\Classes\kiro
+3. Set default value to: URL:Kiro Protocol
+4. Create string value "URL Protocol" with empty data
+5. Create subkey: shell\open\command
+6. Set default value to: "C:\path\to\handler.bat" "%1"`
+
+ case "darwin":
+ return `To manually set up the Kiro protocol handler on macOS:
+
+1. Create ~/Applications/KiroOAuthHandler.app bundle
+2. Add Info.plist with CFBundleURLTypes containing "kiro" scheme
+3. Create executable in Contents/MacOS/
+4. Run: /System/Library/.../lsregister -f ~/Applications/KiroOAuthHandler.app`
+
+ default:
+ return "Protocol handler setup is not supported on this platform."
+ }
+}
+
+// SetupProtocolHandlerIfNeeded checks and installs the protocol handler if needed.
+func SetupProtocolHandlerIfNeeded(handlerPort int) error {
+ if IsProtocolHandlerInstalled() {
+ log.Debug("Kiro protocol handler already installed")
+ return nil
+ }
+
+ fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
+ fmt.Println("║ Kiro Protocol Handler Setup Required ║")
+ fmt.Println("╚══════════════════════════════════════════════════════════╝")
+ fmt.Println("\nTo enable Google/GitHub login, we need to install a protocol handler.")
+ fmt.Println("This allows your browser to redirect back to the CLI after authentication.")
+ fmt.Println("\nInstalling protocol handler...")
+
+ if err := InstallProtocolHandler(handlerPort); err != nil {
+ fmt.Printf("\n⚠ Automatic installation failed: %v\n", err)
+ fmt.Println("\nManual setup instructions:")
+ fmt.Println(strings.Repeat("-", 60))
+ fmt.Println(GetHandlerInstructions())
+ return err
+ }
+
+ fmt.Println("\n✓ Protocol handler installed successfully!")
+ return nil
+}
diff --git a/internal/auth/kiro/rate_limiter.go b/internal/auth/kiro/rate_limiter.go
new file mode 100644
index 00000000..52bb24af
--- /dev/null
+++ b/internal/auth/kiro/rate_limiter.go
@@ -0,0 +1,316 @@
+package kiro
+
+import (
+ "math"
+ "math/rand"
+ "strings"
+ "sync"
+ "time"
+)
+
+const (
+ DefaultMinTokenInterval = 1 * time.Second
+ DefaultMaxTokenInterval = 2 * time.Second
+ DefaultDailyMaxRequests = 500
+ DefaultJitterPercent = 0.3
+ DefaultBackoffBase = 30 * time.Second
+ DefaultBackoffMax = 5 * time.Minute
+ DefaultBackoffMultiplier = 1.5
+ DefaultSuspendCooldown = 1 * time.Hour
+)
+
+// TokenState Token 状态
+type TokenState struct {
+ LastRequest time.Time
+ RequestCount int
+ CooldownEnd time.Time
+ FailCount int
+ DailyRequests int
+ DailyResetTime time.Time
+ IsSuspended bool
+ SuspendedAt time.Time
+ SuspendReason string
+}
+
+// RateLimiter 频率限制器
+type RateLimiter struct {
+ mu sync.RWMutex
+ states map[string]*TokenState
+ minTokenInterval time.Duration
+ maxTokenInterval time.Duration
+ dailyMaxRequests int
+ jitterPercent float64
+ backoffBase time.Duration
+ backoffMax time.Duration
+ backoffMultiplier float64
+ suspendCooldown time.Duration
+ rng *rand.Rand
+}
+
+// NewRateLimiter 创建默认配置的频率限制器
+func NewRateLimiter() *RateLimiter {
+ return &RateLimiter{
+ states: make(map[string]*TokenState),
+ minTokenInterval: DefaultMinTokenInterval,
+ maxTokenInterval: DefaultMaxTokenInterval,
+ dailyMaxRequests: DefaultDailyMaxRequests,
+ jitterPercent: DefaultJitterPercent,
+ backoffBase: DefaultBackoffBase,
+ backoffMax: DefaultBackoffMax,
+ backoffMultiplier: DefaultBackoffMultiplier,
+ suspendCooldown: DefaultSuspendCooldown,
+ rng: rand.New(rand.NewSource(time.Now().UnixNano())),
+ }
+}
+
+// RateLimiterConfig 频率限制器配置
+type RateLimiterConfig struct {
+ MinTokenInterval time.Duration
+ MaxTokenInterval time.Duration
+ DailyMaxRequests int
+ JitterPercent float64
+ BackoffBase time.Duration
+ BackoffMax time.Duration
+ BackoffMultiplier float64
+ SuspendCooldown time.Duration
+}
+
+// NewRateLimiterWithConfig 使用自定义配置创建频率限制器
+func NewRateLimiterWithConfig(cfg RateLimiterConfig) *RateLimiter {
+ rl := NewRateLimiter()
+ if cfg.MinTokenInterval > 0 {
+ rl.minTokenInterval = cfg.MinTokenInterval
+ }
+ if cfg.MaxTokenInterval > 0 {
+ rl.maxTokenInterval = cfg.MaxTokenInterval
+ }
+ if cfg.DailyMaxRequests > 0 {
+ rl.dailyMaxRequests = cfg.DailyMaxRequests
+ }
+ if cfg.JitterPercent > 0 {
+ rl.jitterPercent = cfg.JitterPercent
+ }
+ if cfg.BackoffBase > 0 {
+ rl.backoffBase = cfg.BackoffBase
+ }
+ if cfg.BackoffMax > 0 {
+ rl.backoffMax = cfg.BackoffMax
+ }
+ if cfg.BackoffMultiplier > 0 {
+ rl.backoffMultiplier = cfg.BackoffMultiplier
+ }
+ if cfg.SuspendCooldown > 0 {
+ rl.suspendCooldown = cfg.SuspendCooldown
+ }
+ return rl
+}
+
+// getOrCreateState 获取或创建 Token 状态
+func (rl *RateLimiter) getOrCreateState(tokenKey string) *TokenState {
+ state, exists := rl.states[tokenKey]
+ if !exists {
+ state = &TokenState{
+ DailyResetTime: time.Now().Truncate(24 * time.Hour).Add(24 * time.Hour),
+ }
+ rl.states[tokenKey] = state
+ }
+ return state
+}
+
+// resetDailyIfNeeded 如果需要则重置每日计数
+func (rl *RateLimiter) resetDailyIfNeeded(state *TokenState) {
+ now := time.Now()
+ if now.After(state.DailyResetTime) {
+ state.DailyRequests = 0
+ state.DailyResetTime = now.Truncate(24 * time.Hour).Add(24 * time.Hour)
+ }
+}
+
+// calculateInterval 计算带抖动的随机间隔
+func (rl *RateLimiter) calculateInterval() time.Duration {
+ baseInterval := rl.minTokenInterval + time.Duration(rl.rng.Int63n(int64(rl.maxTokenInterval-rl.minTokenInterval)))
+ jitter := time.Duration(float64(baseInterval) * rl.jitterPercent * (rl.rng.Float64()*2 - 1))
+ return baseInterval + jitter
+}
+
+// WaitForToken 等待 Token 可用(带抖动的随机间隔)
+func (rl *RateLimiter) WaitForToken(tokenKey string) {
+ rl.mu.Lock()
+ state := rl.getOrCreateState(tokenKey)
+ rl.resetDailyIfNeeded(state)
+
+ now := time.Now()
+
+ // 检查是否在冷却期
+ if now.Before(state.CooldownEnd) {
+ waitTime := state.CooldownEnd.Sub(now)
+ rl.mu.Unlock()
+ time.Sleep(waitTime)
+ rl.mu.Lock()
+ state = rl.getOrCreateState(tokenKey)
+ now = time.Now()
+ }
+
+ // 计算距离上次请求的间隔
+ interval := rl.calculateInterval()
+ nextAllowedTime := state.LastRequest.Add(interval)
+
+ if now.Before(nextAllowedTime) {
+ waitTime := nextAllowedTime.Sub(now)
+ rl.mu.Unlock()
+ time.Sleep(waitTime)
+ rl.mu.Lock()
+ state = rl.getOrCreateState(tokenKey)
+ }
+
+ state.LastRequest = time.Now()
+ state.RequestCount++
+ state.DailyRequests++
+ rl.mu.Unlock()
+}
+
+// MarkTokenFailed 标记 Token 失败
+func (rl *RateLimiter) MarkTokenFailed(tokenKey string) {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ state := rl.getOrCreateState(tokenKey)
+ state.FailCount++
+ state.CooldownEnd = time.Now().Add(rl.calculateBackoff(state.FailCount))
+}
+
+// MarkTokenSuccess 标记 Token 成功
+func (rl *RateLimiter) MarkTokenSuccess(tokenKey string) {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ state := rl.getOrCreateState(tokenKey)
+ state.FailCount = 0
+ state.CooldownEnd = time.Time{}
+}
+
+// CheckAndMarkSuspended 检测暂停错误并标记
+func (rl *RateLimiter) CheckAndMarkSuspended(tokenKey string, errorMsg string) bool {
+ suspendKeywords := []string{
+ "suspended",
+ "banned",
+ "disabled",
+ "account has been",
+ "access denied",
+ "rate limit exceeded",
+ "too many requests",
+ "quota exceeded",
+ }
+
+ lowerMsg := strings.ToLower(errorMsg)
+ for _, keyword := range suspendKeywords {
+ if strings.Contains(lowerMsg, keyword) {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ state := rl.getOrCreateState(tokenKey)
+ state.IsSuspended = true
+ state.SuspendedAt = time.Now()
+ state.SuspendReason = errorMsg
+ state.CooldownEnd = time.Now().Add(rl.suspendCooldown)
+ return true
+ }
+ }
+ return false
+}
+
+// IsTokenAvailable 检查 Token 是否可用
+func (rl *RateLimiter) IsTokenAvailable(tokenKey string) bool {
+ rl.mu.RLock()
+ defer rl.mu.RUnlock()
+
+ state, exists := rl.states[tokenKey]
+ if !exists {
+ return true
+ }
+
+ now := time.Now()
+
+ // 检查是否被暂停
+ if state.IsSuspended {
+ if now.After(state.SuspendedAt.Add(rl.suspendCooldown)) {
+ return true
+ }
+ return false
+ }
+
+ // 检查是否在冷却期
+ if now.Before(state.CooldownEnd) {
+ return false
+ }
+
+ // 检查每日请求限制
+ rl.mu.RUnlock()
+ rl.mu.Lock()
+ rl.resetDailyIfNeeded(state)
+ dailyRequests := state.DailyRequests
+ dailyMax := rl.dailyMaxRequests
+ rl.mu.Unlock()
+ rl.mu.RLock()
+
+ if dailyRequests >= dailyMax {
+ return false
+ }
+
+ return true
+}
+
+// calculateBackoff 计算指数退避时间
+func (rl *RateLimiter) calculateBackoff(failCount int) time.Duration {
+ if failCount <= 0 {
+ return 0
+ }
+
+ backoff := float64(rl.backoffBase) * math.Pow(rl.backoffMultiplier, float64(failCount-1))
+
+ // 添加抖动
+ jitter := backoff * rl.jitterPercent * (rl.rng.Float64()*2 - 1)
+ backoff += jitter
+
+ if time.Duration(backoff) > rl.backoffMax {
+ return rl.backoffMax
+ }
+ return time.Duration(backoff)
+}
+
+// GetTokenState 获取 Token 状态(只读)
+func (rl *RateLimiter) GetTokenState(tokenKey string) *TokenState {
+ rl.mu.RLock()
+ defer rl.mu.RUnlock()
+
+ state, exists := rl.states[tokenKey]
+ if !exists {
+ return nil
+ }
+
+ // 返回副本以防止外部修改
+ stateCopy := *state
+ return &stateCopy
+}
+
+// ClearTokenState 清除 Token 状态
+func (rl *RateLimiter) ClearTokenState(tokenKey string) {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+ delete(rl.states, tokenKey)
+}
+
+// ResetSuspension 重置暂停状态
+func (rl *RateLimiter) ResetSuspension(tokenKey string) {
+ rl.mu.Lock()
+ defer rl.mu.Unlock()
+
+ state, exists := rl.states[tokenKey]
+ if exists {
+ state.IsSuspended = false
+ state.SuspendedAt = time.Time{}
+ state.SuspendReason = ""
+ state.CooldownEnd = time.Time{}
+ state.FailCount = 0
+ }
+}
diff --git a/internal/auth/kiro/rate_limiter_singleton.go b/internal/auth/kiro/rate_limiter_singleton.go
new file mode 100644
index 00000000..4c02af89
--- /dev/null
+++ b/internal/auth/kiro/rate_limiter_singleton.go
@@ -0,0 +1,46 @@
+package kiro
+
+import (
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+var (
+ globalRateLimiter *RateLimiter
+ globalRateLimiterOnce sync.Once
+
+ globalCooldownManager *CooldownManager
+ globalCooldownManagerOnce sync.Once
+ cooldownStopCh chan struct{}
+)
+
+// GetGlobalRateLimiter returns the singleton RateLimiter instance.
+func GetGlobalRateLimiter() *RateLimiter {
+ globalRateLimiterOnce.Do(func() {
+ globalRateLimiter = NewRateLimiter()
+ log.Info("kiro: global RateLimiter initialized")
+ })
+ return globalRateLimiter
+}
+
+// GetGlobalCooldownManager returns the singleton CooldownManager instance.
+func GetGlobalCooldownManager() *CooldownManager {
+ globalCooldownManagerOnce.Do(func() {
+ globalCooldownManager = NewCooldownManager()
+ cooldownStopCh = make(chan struct{})
+ go globalCooldownManager.StartCleanupRoutine(5*time.Minute, cooldownStopCh)
+ log.Info("kiro: global CooldownManager initialized with cleanup routine")
+ })
+ return globalCooldownManager
+}
+
+// ShutdownRateLimiters stops the cooldown cleanup routine.
+// Should be called during application shutdown.
+func ShutdownRateLimiters() {
+ if cooldownStopCh != nil {
+ close(cooldownStopCh)
+ log.Info("kiro: rate limiter cleanup routine stopped")
+ }
+}
diff --git a/internal/auth/kiro/rate_limiter_test.go b/internal/auth/kiro/rate_limiter_test.go
new file mode 100644
index 00000000..636413dd
--- /dev/null
+++ b/internal/auth/kiro/rate_limiter_test.go
@@ -0,0 +1,304 @@
+package kiro
+
+import (
+ "sync"
+ "testing"
+ "time"
+)
+
+func TestNewRateLimiter(t *testing.T) {
+ rl := NewRateLimiter()
+ if rl == nil {
+ t.Fatal("expected non-nil RateLimiter")
+ }
+ if rl.states == nil {
+ t.Error("expected non-nil states map")
+ }
+ if rl.minTokenInterval != DefaultMinTokenInterval {
+ t.Errorf("expected minTokenInterval %v, got %v", DefaultMinTokenInterval, rl.minTokenInterval)
+ }
+ if rl.maxTokenInterval != DefaultMaxTokenInterval {
+ t.Errorf("expected maxTokenInterval %v, got %v", DefaultMaxTokenInterval, rl.maxTokenInterval)
+ }
+ if rl.dailyMaxRequests != DefaultDailyMaxRequests {
+ t.Errorf("expected dailyMaxRequests %d, got %d", DefaultDailyMaxRequests, rl.dailyMaxRequests)
+ }
+}
+
+func TestNewRateLimiterWithConfig(t *testing.T) {
+ cfg := RateLimiterConfig{
+ MinTokenInterval: 5 * time.Second,
+ MaxTokenInterval: 15 * time.Second,
+ DailyMaxRequests: 100,
+ JitterPercent: 0.2,
+ BackoffBase: 1 * time.Minute,
+ BackoffMax: 30 * time.Minute,
+ BackoffMultiplier: 1.5,
+ SuspendCooldown: 12 * time.Hour,
+ }
+
+ rl := NewRateLimiterWithConfig(cfg)
+ if rl.minTokenInterval != 5*time.Second {
+ t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
+ }
+ if rl.maxTokenInterval != 15*time.Second {
+ t.Errorf("expected maxTokenInterval 15s, got %v", rl.maxTokenInterval)
+ }
+ if rl.dailyMaxRequests != 100 {
+ t.Errorf("expected dailyMaxRequests 100, got %d", rl.dailyMaxRequests)
+ }
+}
+
+func TestNewRateLimiterWithConfig_PartialConfig(t *testing.T) {
+ cfg := RateLimiterConfig{
+ MinTokenInterval: 5 * time.Second,
+ }
+
+ rl := NewRateLimiterWithConfig(cfg)
+ if rl.minTokenInterval != 5*time.Second {
+ t.Errorf("expected minTokenInterval 5s, got %v", rl.minTokenInterval)
+ }
+ if rl.maxTokenInterval != DefaultMaxTokenInterval {
+ t.Errorf("expected default maxTokenInterval, got %v", rl.maxTokenInterval)
+ }
+}
+
+func TestGetTokenState_NonExistent(t *testing.T) {
+ rl := NewRateLimiter()
+ state := rl.GetTokenState("nonexistent")
+ if state != nil {
+ t.Error("expected nil state for non-existent token")
+ }
+}
+
+func TestIsTokenAvailable_NewToken(t *testing.T) {
+ rl := NewRateLimiter()
+ if !rl.IsTokenAvailable("newtoken") {
+ t.Error("expected new token to be available")
+ }
+}
+
+func TestMarkTokenFailed(t *testing.T) {
+ rl := NewRateLimiter()
+ rl.MarkTokenFailed("token1")
+
+ state := rl.GetTokenState("token1")
+ if state == nil {
+ t.Fatal("expected non-nil state")
+ }
+ if state.FailCount != 1 {
+ t.Errorf("expected FailCount 1, got %d", state.FailCount)
+ }
+ if state.CooldownEnd.IsZero() {
+ t.Error("expected non-zero CooldownEnd")
+ }
+}
+
+func TestMarkTokenSuccess(t *testing.T) {
+ rl := NewRateLimiter()
+ rl.MarkTokenFailed("token1")
+ rl.MarkTokenFailed("token1")
+ rl.MarkTokenSuccess("token1")
+
+ state := rl.GetTokenState("token1")
+ if state == nil {
+ t.Fatal("expected non-nil state")
+ }
+ if state.FailCount != 0 {
+ t.Errorf("expected FailCount 0, got %d", state.FailCount)
+ }
+ if !state.CooldownEnd.IsZero() {
+ t.Error("expected zero CooldownEnd after success")
+ }
+}
+
+func TestCheckAndMarkSuspended_Suspended(t *testing.T) {
+ rl := NewRateLimiter()
+
+ testCases := []string{
+ "Account has been suspended",
+ "You are banned from this service",
+ "Account disabled",
+ "Access denied permanently",
+ "Rate limit exceeded",
+ "Too many requests",
+ "Quota exceeded for today",
+ }
+
+ for i, msg := range testCases {
+ tokenKey := "token" + string(rune('a'+i))
+ if !rl.CheckAndMarkSuspended(tokenKey, msg) {
+ t.Errorf("expected suspension detected for: %s", msg)
+ }
+ state := rl.GetTokenState(tokenKey)
+ if !state.IsSuspended {
+ t.Errorf("expected IsSuspended true for: %s", msg)
+ }
+ }
+}
+
+func TestCheckAndMarkSuspended_NotSuspended(t *testing.T) {
+ rl := NewRateLimiter()
+
+ normalErrors := []string{
+ "connection timeout",
+ "internal server error",
+ "bad request",
+ "invalid token format",
+ }
+
+ for i, msg := range normalErrors {
+ tokenKey := "token" + string(rune('a'+i))
+ if rl.CheckAndMarkSuspended(tokenKey, msg) {
+ t.Errorf("unexpected suspension for: %s", msg)
+ }
+ }
+}
+
+func TestIsTokenAvailable_Suspended(t *testing.T) {
+ rl := NewRateLimiter()
+ rl.CheckAndMarkSuspended("token1", "Account suspended")
+
+ if rl.IsTokenAvailable("token1") {
+ t.Error("expected suspended token to be unavailable")
+ }
+}
+
+func TestClearTokenState(t *testing.T) {
+ rl := NewRateLimiter()
+ rl.MarkTokenFailed("token1")
+ rl.ClearTokenState("token1")
+
+ state := rl.GetTokenState("token1")
+ if state != nil {
+ t.Error("expected nil state after clear")
+ }
+}
+
+func TestResetSuspension(t *testing.T) {
+ rl := NewRateLimiter()
+ rl.CheckAndMarkSuspended("token1", "Account suspended")
+ rl.ResetSuspension("token1")
+
+ state := rl.GetTokenState("token1")
+ if state.IsSuspended {
+ t.Error("expected IsSuspended false after reset")
+ }
+ if state.FailCount != 0 {
+ t.Errorf("expected FailCount 0, got %d", state.FailCount)
+ }
+}
+
+func TestResetSuspension_NonExistent(t *testing.T) {
+ rl := NewRateLimiter()
+ rl.ResetSuspension("nonexistent")
+}
+
+func TestCalculateBackoff_ZeroFailCount(t *testing.T) {
+ rl := NewRateLimiter()
+ backoff := rl.calculateBackoff(0)
+ if backoff != 0 {
+ t.Errorf("expected 0 backoff for 0 fails, got %v", backoff)
+ }
+}
+
+func TestCalculateBackoff_Exponential(t *testing.T) {
+ cfg := RateLimiterConfig{
+ BackoffBase: 1 * time.Minute,
+ BackoffMax: 60 * time.Minute,
+ BackoffMultiplier: 2.0,
+ JitterPercent: 0.3,
+ }
+ rl := NewRateLimiterWithConfig(cfg)
+
+ backoff1 := rl.calculateBackoff(1)
+ if backoff1 < 40*time.Second || backoff1 > 80*time.Second {
+ t.Errorf("expected ~1min (with jitter) for fail 1, got %v", backoff1)
+ }
+
+ backoff2 := rl.calculateBackoff(2)
+ if backoff2 < 80*time.Second || backoff2 > 160*time.Second {
+ t.Errorf("expected ~2min (with jitter) for fail 2, got %v", backoff2)
+ }
+}
+
+func TestCalculateBackoff_MaxCap(t *testing.T) {
+ cfg := RateLimiterConfig{
+ BackoffBase: 1 * time.Minute,
+ BackoffMax: 10 * time.Minute,
+ BackoffMultiplier: 2.0,
+ JitterPercent: 0,
+ }
+ rl := NewRateLimiterWithConfig(cfg)
+
+ backoff := rl.calculateBackoff(10)
+ if backoff > 10*time.Minute {
+ t.Errorf("expected backoff capped at 10min, got %v", backoff)
+ }
+}
+
+func TestGetTokenState_ReturnsCopy(t *testing.T) {
+ rl := NewRateLimiter()
+ rl.MarkTokenFailed("token1")
+
+ state1 := rl.GetTokenState("token1")
+ state1.FailCount = 999
+
+ state2 := rl.GetTokenState("token1")
+ if state2.FailCount == 999 {
+ t.Error("GetTokenState should return a copy")
+ }
+}
+
+func TestRateLimiter_ConcurrentAccess(t *testing.T) {
+ rl := NewRateLimiter()
+ const numGoroutines = 50
+ const numOperations = 50
+
+ var wg sync.WaitGroup
+ wg.Add(numGoroutines)
+
+ for i := 0; i < numGoroutines; i++ {
+ go func(id int) {
+ defer wg.Done()
+ tokenKey := "token" + string(rune('a'+id%10))
+ for j := 0; j < numOperations; j++ {
+ switch j % 6 {
+ case 0:
+ rl.IsTokenAvailable(tokenKey)
+ case 1:
+ rl.MarkTokenFailed(tokenKey)
+ case 2:
+ rl.MarkTokenSuccess(tokenKey)
+ case 3:
+ rl.GetTokenState(tokenKey)
+ case 4:
+ rl.CheckAndMarkSuspended(tokenKey, "test error")
+ case 5:
+ rl.ResetSuspension(tokenKey)
+ }
+ }
+ }(i)
+ }
+
+ wg.Wait()
+}
+
+func TestCalculateInterval_WithinRange(t *testing.T) {
+ cfg := RateLimiterConfig{
+ MinTokenInterval: 10 * time.Second,
+ MaxTokenInterval: 30 * time.Second,
+ JitterPercent: 0.3,
+ }
+ rl := NewRateLimiterWithConfig(cfg)
+
+ minAllowed := 7 * time.Second
+ maxAllowed := 40 * time.Second
+
+ for i := 0; i < 100; i++ {
+ interval := rl.calculateInterval()
+ if interval < minAllowed || interval > maxAllowed {
+ t.Errorf("interval %v outside expected range [%v, %v]", interval, minAllowed, maxAllowed)
+ }
+ }
+}
diff --git a/internal/auth/kiro/refresh_manager.go b/internal/auth/kiro/refresh_manager.go
new file mode 100644
index 00000000..5330c5e1
--- /dev/null
+++ b/internal/auth/kiro/refresh_manager.go
@@ -0,0 +1,180 @@
+package kiro
+
+import (
+ "context"
+ "sync"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+ log "github.com/sirupsen/logrus"
+)
+
+// RefreshManager 是后台刷新器的单例管理器
+type RefreshManager struct {
+ mu sync.Mutex
+ refresher *BackgroundRefresher
+ ctx context.Context
+ cancel context.CancelFunc
+ started bool
+ onTokenRefreshed func(tokenID string, tokenData *KiroTokenData) // 刷新成功回调
+}
+
+var (
+ globalRefreshManager *RefreshManager
+ managerOnce sync.Once
+)
+
+// GetRefreshManager 获取全局刷新管理器实例
+func GetRefreshManager() *RefreshManager {
+ managerOnce.Do(func() {
+ globalRefreshManager = &RefreshManager{}
+ })
+ return globalRefreshManager
+}
+
+// Initialize 初始化后台刷新器
+// baseDir: token 文件所在的目录
+// cfg: 应用配置
+func (m *RefreshManager) Initialize(baseDir string, cfg *config.Config) error {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.started {
+ log.Debug("refresh manager: already initialized")
+ return nil
+ }
+
+ if baseDir == "" {
+ log.Warn("refresh manager: base directory not provided, skipping initialization")
+ return nil
+ }
+
+ resolvedBaseDir, err := util.ResolveAuthDir(baseDir)
+ if err != nil {
+ log.Warnf("refresh manager: failed to resolve auth directory %s: %v", baseDir, err)
+ }
+ if resolvedBaseDir != "" {
+ baseDir = resolvedBaseDir
+ }
+
+ // 创建 token 存储库
+ repo := NewFileTokenRepository(baseDir)
+
+ // 创建后台刷新器,配置参数
+ opts := []RefresherOption{
+ WithInterval(time.Minute), // 每分钟检查一次
+ WithBatchSize(50), // 每批最多处理 50 个 token
+ WithConcurrency(10), // 最多 10 个并发刷新
+ WithConfig(cfg), // 设置 OAuth 和 SSO 客户端
+ }
+
+ // 如果已设置回调,传递给 BackgroundRefresher
+ if m.onTokenRefreshed != nil {
+ opts = append(opts, WithOnTokenRefreshed(m.onTokenRefreshed))
+ }
+
+ m.refresher = NewBackgroundRefresher(repo, opts...)
+
+ log.Infof("refresh manager: initialized with base directory %s", baseDir)
+ return nil
+}
+
+// Start 启动后台刷新
+func (m *RefreshManager) Start() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.started {
+ log.Debug("refresh manager: already started")
+ return
+ }
+
+ if m.refresher == nil {
+ log.Warn("refresh manager: not initialized, cannot start")
+ return
+ }
+
+ m.ctx, m.cancel = context.WithCancel(context.Background())
+ m.refresher.Start(m.ctx)
+ m.started = true
+
+ log.Info("refresh manager: background refresh started")
+}
+
+// Stop 停止后台刷新
+func (m *RefreshManager) Stop() {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if !m.started {
+ return
+ }
+
+ if m.cancel != nil {
+ m.cancel()
+ }
+
+ if m.refresher != nil {
+ m.refresher.Stop()
+ }
+
+ m.started = false
+ log.Info("refresh manager: background refresh stopped")
+}
+
+// IsRunning 检查后台刷新是否正在运行
+func (m *RefreshManager) IsRunning() bool {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ return m.started
+}
+
+// UpdateBaseDir 更新 token 目录(用于运行时配置更改)
+func (m *RefreshManager) UpdateBaseDir(baseDir string) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ if m.refresher != nil && m.refresher.tokenRepo != nil {
+ if repo, ok := m.refresher.tokenRepo.(*FileTokenRepository); ok {
+ repo.SetBaseDir(baseDir)
+ log.Infof("refresh manager: updated base directory to %s", baseDir)
+ }
+ }
+}
+
+// SetOnTokenRefreshed 设置 token 刷新成功后的回调函数
+// 可以在任何时候调用,支持运行时更新回调
+// callback: 回调函数,接收 tokenID(文件名)和新的 token 数据
+func (m *RefreshManager) SetOnTokenRefreshed(callback func(tokenID string, tokenData *KiroTokenData)) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ m.onTokenRefreshed = callback
+
+ // 如果 refresher 已经创建,使用并发安全的方式更新它的回调
+ if m.refresher != nil {
+ m.refresher.callbackMu.Lock()
+ m.refresher.onTokenRefreshed = callback
+ m.refresher.callbackMu.Unlock()
+ }
+
+ log.Debug("refresh manager: token refresh callback registered")
+}
+
+// InitializeAndStart 初始化并启动后台刷新(便捷方法)
+func InitializeAndStart(baseDir string, cfg *config.Config) {
+ manager := GetRefreshManager()
+ if err := manager.Initialize(baseDir, cfg); err != nil {
+ log.Errorf("refresh manager: initialization failed: %v", err)
+ return
+ }
+ manager.Start()
+}
+
+// StopGlobalRefreshManager 停止全局刷新管理器
+func StopGlobalRefreshManager() {
+ if globalRefreshManager != nil {
+ globalRefreshManager.Stop()
+ }
+}
diff --git a/internal/auth/kiro/refresh_utils.go b/internal/auth/kiro/refresh_utils.go
new file mode 100644
index 00000000..5abb714c
--- /dev/null
+++ b/internal/auth/kiro/refresh_utils.go
@@ -0,0 +1,159 @@
+// Package kiro provides refresh utilities for Kiro token management.
+package kiro
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// RefreshResult contains the result of a token refresh attempt.
+type RefreshResult struct {
+ TokenData *KiroTokenData
+ Error error
+ UsedFallback bool // True if we used the existing token as fallback
+}
+
+// RefreshWithGracefulDegradation attempts to refresh a token with graceful degradation.
+// If refresh fails but the existing access token is still valid, it returns the existing token.
+// This matches kiro-openai-gateway's behavior for better reliability.
+//
+// Parameters:
+// - ctx: Context for the request
+// - refreshFunc: Function to perform the actual refresh
+// - existingAccessToken: Current access token (for fallback)
+// - expiresAt: Expiration time of the existing token
+//
+// Returns:
+// - RefreshResult containing the new or existing token data
+func RefreshWithGracefulDegradation(
+ ctx context.Context,
+ refreshFunc func(ctx context.Context) (*KiroTokenData, error),
+ existingAccessToken string,
+ expiresAt time.Time,
+) RefreshResult {
+ // Try to refresh the token
+ newTokenData, err := refreshFunc(ctx)
+ if err == nil {
+ return RefreshResult{
+ TokenData: newTokenData,
+ Error: nil,
+ UsedFallback: false,
+ }
+ }
+
+ // Refresh failed - check if we can use the existing token
+ log.Warnf("kiro: token refresh failed: %v", err)
+
+ // Check if existing token is still valid (not expired)
+ if existingAccessToken != "" && time.Now().Before(expiresAt) {
+ remainingTime := time.Until(expiresAt)
+ log.Warnf("kiro: using existing access token (expires in %v). Will retry refresh later.", remainingTime.Round(time.Second))
+
+ return RefreshResult{
+ TokenData: &KiroTokenData{
+ AccessToken: existingAccessToken,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ },
+ Error: nil,
+ UsedFallback: true,
+ }
+ }
+
+ // Token is expired and refresh failed - return the error
+ return RefreshResult{
+ TokenData: nil,
+ Error: fmt.Errorf("token refresh failed and existing token is expired: %w", err),
+ UsedFallback: false,
+ }
+}
+
+// IsTokenExpiringSoon checks if a token is expiring within the given threshold.
+// Default threshold is 5 minutes if not specified.
+func IsTokenExpiringSoon(expiresAt time.Time, threshold time.Duration) bool {
+ if threshold == 0 {
+ threshold = 5 * time.Minute
+ }
+ return time.Now().Add(threshold).After(expiresAt)
+}
+
+// IsTokenExpired checks if a token has already expired.
+func IsTokenExpired(expiresAt time.Time) bool {
+ return time.Now().After(expiresAt)
+}
+
+// ParseExpiresAt parses an expiration time string in RFC3339 format.
+// Returns zero time if parsing fails.
+func ParseExpiresAt(expiresAtStr string) time.Time {
+ if expiresAtStr == "" {
+ return time.Time{}
+ }
+ t, err := time.Parse(time.RFC3339, expiresAtStr)
+ if err != nil {
+ log.Debugf("kiro: failed to parse expiresAt '%s': %v", expiresAtStr, err)
+ return time.Time{}
+ }
+ return t
+}
+
+// RefreshConfig contains configuration for token refresh behavior.
+type RefreshConfig struct {
+ // MaxRetries is the maximum number of refresh attempts (default: 1)
+ MaxRetries int
+ // RetryDelay is the delay between retry attempts (default: 1 second)
+ RetryDelay time.Duration
+ // RefreshThreshold is how early to refresh before expiration (default: 5 minutes)
+ RefreshThreshold time.Duration
+ // EnableGracefulDegradation allows using existing token if refresh fails (default: true)
+ EnableGracefulDegradation bool
+}
+
+// DefaultRefreshConfig returns the default refresh configuration.
+func DefaultRefreshConfig() RefreshConfig {
+ return RefreshConfig{
+ MaxRetries: 1,
+ RetryDelay: time.Second,
+ RefreshThreshold: 5 * time.Minute,
+ EnableGracefulDegradation: true,
+ }
+}
+
+// RefreshWithRetry attempts to refresh a token with retry logic.
+func RefreshWithRetry(
+ ctx context.Context,
+ refreshFunc func(ctx context.Context) (*KiroTokenData, error),
+ config RefreshConfig,
+) (*KiroTokenData, error) {
+ var lastErr error
+
+ maxAttempts := config.MaxRetries + 1
+ if maxAttempts < 1 {
+ maxAttempts = 1
+ }
+
+ for attempt := 1; attempt <= maxAttempts; attempt++ {
+ tokenData, err := refreshFunc(ctx)
+ if err == nil {
+ if attempt > 1 {
+ log.Infof("kiro: token refresh succeeded on attempt %d", attempt)
+ }
+ return tokenData, nil
+ }
+
+ lastErr = err
+ log.Warnf("kiro: token refresh attempt %d/%d failed: %v", attempt, maxAttempts, err)
+
+ // Don't sleep after the last attempt
+ if attempt < maxAttempts {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(config.RetryDelay):
+ }
+ }
+ }
+
+ return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxAttempts, lastErr)
+}
diff --git a/internal/auth/kiro/social_auth.go b/internal/auth/kiro/social_auth.go
new file mode 100644
index 00000000..65f31ba4
--- /dev/null
+++ b/internal/auth/kiro/social_auth.go
@@ -0,0 +1,481 @@
+// Package kiro provides social authentication (Google/GitHub) for Kiro via AuthServiceClient.
+package kiro
+
+import (
+ "bufio"
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "html"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "os"
+ "os/exec"
+ "runtime"
+ "strings"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/term"
+)
+
+const (
+ // Kiro AuthService endpoint
+ kiroAuthServiceEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
+
+ // OAuth timeout
+ socialAuthTimeout = 10 * time.Minute
+
+ // Default callback port for social auth HTTP server
+ socialAuthCallbackPort = 9876
+)
+
+// SocialProvider represents the social login provider.
+type SocialProvider string
+
+const (
+ // ProviderGoogle is Google OAuth provider
+ ProviderGoogle SocialProvider = "Google"
+ // ProviderGitHub is GitHub OAuth provider
+ ProviderGitHub SocialProvider = "Github"
+ // Note: AWS Builder ID is NOT supported by Kiro's auth service.
+ // It only supports: Google, Github, Cognito
+ // AWS Builder ID must use device code flow via SSO OIDC.
+)
+
+// CreateTokenRequest is sent to Kiro's /oauth/token endpoint.
+type CreateTokenRequest struct {
+ Code string `json:"code"`
+ CodeVerifier string `json:"code_verifier"`
+ RedirectURI string `json:"redirect_uri"`
+ InvitationCode string `json:"invitation_code,omitempty"`
+}
+
+// SocialTokenResponse from Kiro's /oauth/token endpoint for social auth.
+type SocialTokenResponse struct {
+ AccessToken string `json:"accessToken"`
+ RefreshToken string `json:"refreshToken"`
+ ProfileArn string `json:"profileArn"`
+ ExpiresIn int `json:"expiresIn"`
+}
+
+// RefreshTokenRequest is sent to Kiro's /refreshToken endpoint.
+type RefreshTokenRequest struct {
+ RefreshToken string `json:"refreshToken"`
+}
+
+// WebCallbackResult contains the OAuth callback result from HTTP server.
+type WebCallbackResult struct {
+ Code string
+ State string
+ Error string
+}
+
+// SocialAuthClient handles social authentication with Kiro.
+type SocialAuthClient struct {
+ httpClient *http.Client
+ cfg *config.Config
+ protocolHandler *ProtocolHandler
+}
+
+// NewSocialAuthClient creates a new social auth client.
+func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
+ client := &http.Client{Timeout: 30 * time.Second}
+ if cfg != nil {
+ client = util.SetProxy(&cfg.SDKConfig, client)
+ }
+ return &SocialAuthClient{
+ httpClient: client,
+ cfg: cfg,
+ protocolHandler: NewProtocolHandler(),
+ }
+}
+
+// startWebCallbackServer starts a local HTTP server to receive the OAuth callback.
+// This is used instead of the kiro:// protocol handler to avoid redirect_mismatch errors.
+func (c *SocialAuthClient) startWebCallbackServer(ctx context.Context, expectedState string) (string, <-chan WebCallbackResult, error) {
+ // Try to find an available port - use localhost like Kiro does
+ listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", socialAuthCallbackPort))
+ if err != nil {
+ // Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
+ log.Warnf("kiro social auth: default port %d is busy, falling back to dynamic port", socialAuthCallbackPort)
+ listener, err = net.Listen("tcp", "localhost:0")
+ if err != nil {
+ return "", nil, fmt.Errorf("failed to start callback server: %w", err)
+ }
+ }
+
+ port := listener.Addr().(*net.TCPAddr).Port
+ // Use http scheme for local callback server
+ redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
+ resultChan := make(chan WebCallbackResult, 1)
+
+ server := &http.Server{
+ ReadHeaderTimeout: 10 * time.Second,
+ }
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
+ code := r.URL.Query().Get("code")
+ state := r.URL.Query().Get("state")
+ errParam := r.URL.Query().Get("error")
+
+ if errParam != "" {
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ w.WriteHeader(http.StatusBadRequest)
+ fmt.Fprintf(w, `
+Login Failed
+Login Failed
%s
You can close this window.
`, html.EscapeString(errParam))
+ resultChan <- WebCallbackResult{Error: errParam}
+ return
+ }
+
+ if state != expectedState {
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ w.WriteHeader(http.StatusBadRequest)
+ fmt.Fprint(w, `
+Login Failed
+Login Failed
Invalid state parameter
You can close this window.
`)
+ resultChan <- WebCallbackResult{Error: "state mismatch"}
+ return
+ }
+
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ fmt.Fprint(w, `
+Login Successful
+Login Successful!
You can close this window and return to the terminal.
+`)
+ resultChan <- WebCallbackResult{Code: code, State: state}
+ })
+
+ server.Handler = mux
+
+ go func() {
+ if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
+ log.Debugf("kiro social auth callback server error: %v", err)
+ }
+ }()
+
+ go func() {
+ select {
+ case <-ctx.Done():
+ case <-time.After(socialAuthTimeout):
+ case <-resultChan:
+ }
+ _ = server.Shutdown(context.Background())
+ }()
+
+ return redirectURI, resultChan, nil
+}
+
+// generatePKCE generates PKCE code verifier and challenge.
+func generatePKCE() (verifier, challenge string, err error) {
+ // Generate 32 bytes of random data for verifier
+ b := make([]byte, 32)
+ if _, err := rand.Read(b); err != nil {
+ return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
+ }
+ verifier = base64.RawURLEncoding.EncodeToString(b)
+
+ // Generate SHA256 hash of verifier for challenge
+ h := sha256.Sum256([]byte(verifier))
+ challenge = base64.RawURLEncoding.EncodeToString(h[:])
+
+ return verifier, challenge, nil
+}
+
+// generateState generates a random state parameter.
+func generateStateParam() (string, error) {
+ b := make([]byte, 16)
+ if _, err := rand.Read(b); err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// buildLoginURL constructs the Kiro OAuth login URL.
+// The login endpoint expects a GET request with query parameters.
+// Format: /login?idp=Google&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...&prompt=select_account
+// The prompt=select_account parameter forces the account selection screen even if already logged in.
+func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, state string) string {
+ return fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
+ kiroAuthServiceEndpoint,
+ provider,
+ url.QueryEscape(redirectURI),
+ codeChallenge,
+ state,
+ )
+}
+
+// CreateToken exchanges the authorization code for tokens.
+func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
+ body, err := json.Marshal(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal token request: %w", err)
+ }
+
+ tokenURL := kiroAuthServiceEndpoint + "/oauth/token"
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create token request: %w", err)
+ }
+
+ httpReq.Header.Set("Content-Type", "application/json")
+ httpReq.Header.Set("User-Agent", "KiroIDE-0.7.45-cli-proxy-api")
+
+ resp, err := c.httpClient.Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("token request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read token response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
+ }
+
+ var tokenResp SocialTokenResponse
+ if err := json.Unmarshal(respBody, &tokenResp); err != nil {
+ return nil, fmt.Errorf("failed to parse token response: %w", err)
+ }
+
+ return &tokenResp, nil
+}
+
+// RefreshSocialToken refreshes an expired social auth token.
+func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
+ body, err := json.Marshal(&RefreshTokenRequest{RefreshToken: refreshToken})
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
+ }
+
+ refreshURL := kiroAuthServiceEndpoint + "/refreshToken"
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create refresh request: %w", err)
+ }
+
+ httpReq.Header.Set("Content-Type", "application/json")
+ httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
+
+ resp, err := c.httpClient.Do(httpReq)
+ if err != nil {
+ return nil, fmt.Errorf("refresh request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read refresh response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
+ }
+
+ var tokenResp SocialTokenResponse
+ if err := json.Unmarshal(respBody, &tokenResp); err != nil {
+ return nil, fmt.Errorf("failed to parse refresh response: %w", err)
+ }
+
+ // Validate ExpiresIn - use default 1 hour if invalid
+ expiresIn := tokenResp.ExpiresIn
+ if expiresIn <= 0 {
+ expiresIn = 3600 // Default 1 hour
+ }
+ expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
+
+ return &KiroTokenData{
+ AccessToken: tokenResp.AccessToken,
+ RefreshToken: tokenResp.RefreshToken,
+ ProfileArn: tokenResp.ProfileArn,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ AuthMethod: "social",
+ Provider: "", // Caller should preserve original provider
+ Region: "us-east-1",
+ }, nil
+}
+
+// LoginWithSocial performs OAuth login with Google or GitHub.
+// Uses local HTTP callback server instead of custom protocol handler to avoid redirect_mismatch errors.
+func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
+ providerName := string(provider)
+
+ fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
+ fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
+ fmt.Println("╚══════════════════════════════════════════════════════════╝")
+
+ // Step 1: Start local HTTP callback server (instead of kiro:// protocol handler)
+ // This avoids redirect_mismatch errors with AWS Cognito
+ fmt.Println("\nSetting up authentication...")
+
+ // Step 2: Generate PKCE codes
+ codeVerifier, codeChallenge, err := generatePKCE()
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate PKCE: %w", err)
+ }
+
+ // Step 3: Generate state
+ state, err := generateStateParam()
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate state: %w", err)
+ }
+
+ // Step 4: Start local HTTP callback server
+ redirectURI, resultChan, err := c.startWebCallbackServer(ctx, state)
+ if err != nil {
+ return nil, fmt.Errorf("failed to start callback server: %w", err)
+ }
+ log.Debugf("kiro social auth: callback server started at %s", redirectURI)
+
+ // Step 5: Build the login URL using HTTP redirect URI
+ authURL := c.buildLoginURL(providerName, redirectURI, codeChallenge, state)
+
+ // Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
+ // Incognito mode enables multi-account support by bypassing cached sessions
+ if c.cfg != nil {
+ browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
+ if !c.cfg.IncognitoBrowser {
+ log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
+ } else {
+ log.Debug("kiro: using incognito mode for multi-account support")
+ }
+ } else {
+ browser.SetIncognitoMode(true) // Default to incognito if no config
+ log.Debug("kiro: using incognito mode for multi-account support (default)")
+ }
+
+ // Step 6: Open browser for user authentication
+ fmt.Println("\n════════════════════════════════════════════════════════════")
+ fmt.Printf(" Opening browser for %s authentication...\n", providerName)
+ fmt.Println("════════════════════════════════════════════════════════════")
+ fmt.Printf("\n URL: %s\n\n", authURL)
+
+ if err := browser.OpenURL(authURL); err != nil {
+ log.Warnf("Could not open browser automatically: %v", err)
+ fmt.Println(" ⚠ Could not open browser automatically.")
+ fmt.Println(" Please open the URL above in your browser manually.")
+ } else {
+ fmt.Println(" (Browser opened automatically)")
+ }
+
+ fmt.Println("\n Waiting for authentication callback...")
+
+ // Step 7: Wait for callback from HTTP server
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-time.After(socialAuthTimeout):
+ return nil, fmt.Errorf("authentication timed out")
+ case callback := <-resultChan:
+ if callback.Error != "" {
+ return nil, fmt.Errorf("authentication error: %s", callback.Error)
+ }
+
+ // State is already validated by the callback server
+ if callback.Code == "" {
+ return nil, fmt.Errorf("no authorization code received")
+ }
+
+ fmt.Println("\n✓ Authorization received!")
+
+ // Step 8: Exchange code for tokens
+ fmt.Println("Exchanging code for tokens...")
+
+ tokenReq := &CreateTokenRequest{
+ Code: callback.Code,
+ CodeVerifier: codeVerifier,
+ RedirectURI: redirectURI, // Use HTTP redirect URI, not kiro:// protocol
+ }
+
+ tokenResp, err := c.CreateToken(ctx, tokenReq)
+ if err != nil {
+ return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
+ }
+
+ fmt.Println("\n✓ Authentication successful!")
+
+ // Close the browser window
+ if err := browser.CloseBrowser(); err != nil {
+ log.Debugf("Failed to close browser: %v", err)
+ }
+
+ // Validate ExpiresIn - use default 1 hour if invalid
+ expiresIn := tokenResp.ExpiresIn
+ if expiresIn <= 0 {
+ expiresIn = 3600
+ }
+ expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
+
+ // Try to extract email from JWT access token first
+ email := ExtractEmailFromJWT(tokenResp.AccessToken)
+
+ // If no email in JWT, ask user for account label (only in interactive mode)
+ if email == "" && isInteractiveTerminal() {
+ fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
+ reader := bufio.NewReader(os.Stdin)
+ var err error
+ email, err = reader.ReadString('\n')
+ if err != nil {
+ log.Debugf("Failed to read account label: %v", err)
+ }
+ email = strings.TrimSpace(email)
+ }
+
+ return &KiroTokenData{
+ AccessToken: tokenResp.AccessToken,
+ RefreshToken: tokenResp.RefreshToken,
+ ProfileArn: tokenResp.ProfileArn,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ AuthMethod: "social",
+ Provider: providerName,
+ Email: email, // JWT email or user-provided label
+ Region: "us-east-1",
+ }, nil
+ }
+}
+
+// LoginWithGoogle performs OAuth login with Google.
+func (c *SocialAuthClient) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
+ return c.LoginWithSocial(ctx, ProviderGoogle)
+}
+
+// LoginWithGitHub performs OAuth login with GitHub.
+func (c *SocialAuthClient) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
+ return c.LoginWithSocial(ctx, ProviderGitHub)
+}
+
+// forceDefaultProtocolHandler sets our protocol handler as the default for kiro:// URLs.
+// This prevents the "Open with" dialog from appearing on Linux.
+// On non-Linux platforms, this is a no-op as they use different mechanisms.
+func forceDefaultProtocolHandler() {
+ if runtime.GOOS != "linux" {
+ return // Non-Linux platforms use different handler mechanisms
+ }
+
+ // Set our handler as default using xdg-mime
+ cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
+ if err := cmd.Run(); err != nil {
+ log.Warnf("Failed to set default protocol handler: %v. You may see a handler selection dialog.", err)
+ }
+}
+
+// isInteractiveTerminal checks if stdin is connected to an interactive terminal.
+// Returns false in CI/automated environments or when stdin is piped.
+func isInteractiveTerminal() bool {
+ return term.IsTerminal(int(os.Stdin.Fd()))
+}
diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go
new file mode 100644
index 00000000..60fb8871
--- /dev/null
+++ b/internal/auth/kiro/sso_oidc.go
@@ -0,0 +1,1380 @@
+// Package kiro provides AWS SSO OIDC authentication for Kiro.
+package kiro
+
+import (
+ "bufio"
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "html"
+ "io"
+ "net"
+ "net/http"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ // AWS SSO OIDC endpoints
+ ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com"
+
+ // Kiro's start URL for Builder ID
+ builderIDStartURL = "https://view.awsapps.com/start"
+
+ // Default region for IDC
+ defaultIDCRegion = "us-east-1"
+
+ // Polling interval
+ pollInterval = 5 * time.Second
+
+ // Authorization code flow callback
+ authCodeCallbackPath = "/oauth/callback"
+ authCodeCallbackPort = 19877
+
+ // User-Agent to match official Kiro IDE
+ kiroUserAgent = "KiroIDE"
+
+ // IDC token refresh headers (matching Kiro IDE behavior)
+ idcAmzUserAgent = "aws-sdk-js/3.738.0 ua/2.1 os/other lang/js md/browser#unknown_unknown api/sso-oidc#3.738.0 m/E KiroIDE"
+)
+
+// Sentinel errors for OIDC token polling
+var (
+ ErrAuthorizationPending = errors.New("authorization_pending")
+ ErrSlowDown = errors.New("slow_down")
+)
+
+// SSOOIDCClient handles AWS SSO OIDC authentication.
+type SSOOIDCClient struct {
+ httpClient *http.Client
+ cfg *config.Config
+}
+
+// NewSSOOIDCClient creates a new SSO OIDC client.
+func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient {
+ client := &http.Client{Timeout: 30 * time.Second}
+ if cfg != nil {
+ client = util.SetProxy(&cfg.SDKConfig, client)
+ }
+ return &SSOOIDCClient{
+ httpClient: client,
+ cfg: cfg,
+ }
+}
+
+// RegisterClientResponse from AWS SSO OIDC.
+type RegisterClientResponse struct {
+ ClientID string `json:"clientId"`
+ ClientSecret string `json:"clientSecret"`
+ ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
+ ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
+}
+
+// StartDeviceAuthResponse from AWS SSO OIDC.
+type StartDeviceAuthResponse struct {
+ DeviceCode string `json:"deviceCode"`
+ UserCode string `json:"userCode"`
+ VerificationURI string `json:"verificationUri"`
+ VerificationURIComplete string `json:"verificationUriComplete"`
+ ExpiresIn int `json:"expiresIn"`
+ Interval int `json:"interval"`
+}
+
+// CreateTokenResponse from AWS SSO OIDC.
+type CreateTokenResponse struct {
+ AccessToken string `json:"accessToken"`
+ TokenType string `json:"tokenType"`
+ ExpiresIn int `json:"expiresIn"`
+ RefreshToken string `json:"refreshToken"`
+}
+
+// getOIDCEndpoint returns the OIDC endpoint for the given region.
+func getOIDCEndpoint(region string) string {
+ if region == "" {
+ region = defaultIDCRegion
+ }
+ return fmt.Sprintf("https://oidc.%s.amazonaws.com", region)
+}
+
+// promptInput prompts the user for input with an optional default value.
+func promptInput(prompt, defaultValue string) string {
+ reader := bufio.NewReader(os.Stdin)
+ if defaultValue != "" {
+ fmt.Printf("%s [%s]: ", prompt, defaultValue)
+ } else {
+ fmt.Printf("%s: ", prompt)
+ }
+ input, err := reader.ReadString('\n')
+ if err != nil {
+ log.Warnf("Error reading input: %v", err)
+ return defaultValue
+ }
+ input = strings.TrimSpace(input)
+ if input == "" {
+ return defaultValue
+ }
+ return input
+}
+
+// promptSelect prompts the user to select from options using number input.
+func promptSelect(prompt string, options []string) int {
+ reader := bufio.NewReader(os.Stdin)
+
+ for {
+ fmt.Println(prompt)
+ for i, opt := range options {
+ fmt.Printf(" %d) %s\n", i+1, opt)
+ }
+ fmt.Printf("Enter selection (1-%d): ", len(options))
+
+ input, err := reader.ReadString('\n')
+ if err != nil {
+ log.Warnf("Error reading input: %v", err)
+ return 0 // Default to first option on error
+ }
+ input = strings.TrimSpace(input)
+
+ // Parse the selection
+ var selection int
+ if _, err := fmt.Sscanf(input, "%d", &selection); err != nil || selection < 1 || selection > len(options) {
+ fmt.Printf("Invalid selection '%s'. Please enter a number between 1 and %d.\n\n", input, len(options))
+ continue
+ }
+ return selection - 1
+ }
+}
+
+// RegisterClientWithRegion registers a new OIDC client with AWS using a specific region.
+func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) {
+ endpoint := getOIDCEndpoint(region)
+
+ payload := map[string]interface{}{
+ "clientName": "Kiro IDE",
+ "clientType": "public",
+ "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
+ "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/client/register", strings.NewReader(string(body)))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", kiroUserAgent)
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
+ }
+
+ var result RegisterClientResponse
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return nil, err
+ }
+
+ return &result, nil
+}
+
+// StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC.
+func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) {
+ endpoint := getOIDCEndpoint(region)
+
+ payload := map[string]string{
+ "clientId": clientID,
+ "clientSecret": clientSecret,
+ "startUrl": startURL,
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/device_authorization", strings.NewReader(string(body)))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", kiroUserAgent)
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode)
+ }
+
+ var result StartDeviceAuthResponse
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return nil, err
+ }
+
+ return &result, nil
+}
+
+// CreateTokenWithRegion polls for the access token after user authorization using a specific region.
+func (c *SSOOIDCClient) CreateTokenWithRegion(ctx context.Context, clientID, clientSecret, deviceCode, region string) (*CreateTokenResponse, error) {
+ endpoint := getOIDCEndpoint(region)
+
+ payload := map[string]string{
+ "clientId": clientID,
+ "clientSecret": clientSecret,
+ "deviceCode": deviceCode,
+ "grantType": "urn:ietf:params:oauth:grant-type:device_code",
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", kiroUserAgent)
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check for pending authorization
+ if resp.StatusCode == http.StatusBadRequest {
+ var errResp struct {
+ Error string `json:"error"`
+ }
+ if json.Unmarshal(respBody, &errResp) == nil {
+ if errResp.Error == "authorization_pending" {
+ return nil, ErrAuthorizationPending
+ }
+ if errResp.Error == "slow_down" {
+ return nil, ErrSlowDown
+ }
+ }
+ log.Debugf("create token failed: %s", string(respBody))
+ return nil, fmt.Errorf("create token failed")
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
+ }
+
+ var result CreateTokenResponse
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return nil, err
+ }
+
+ return &result, nil
+}
+
+// RefreshTokenWithRegion refreshes an access token using the refresh token with a specific region.
+func (c *SSOOIDCClient) RefreshTokenWithRegion(ctx context.Context, clientID, clientSecret, refreshToken, region, startURL string) (*KiroTokenData, error) {
+ endpoint := getOIDCEndpoint(region)
+
+ payload := map[string]string{
+ "clientId": clientID,
+ "clientSecret": clientSecret,
+ "refreshToken": refreshToken,
+ "grantType": "refresh_token",
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/token", strings.NewReader(string(body)))
+ if err != nil {
+ return nil, err
+ }
+
+ // Set headers matching kiro2api's IDC token refresh
+ // These headers are required for successful IDC token refresh
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Host", fmt.Sprintf("oidc.%s.amazonaws.com", region))
+ req.Header.Set("Connection", "keep-alive")
+ req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
+ req.Header.Set("Accept", "*/*")
+ req.Header.Set("Accept-Language", "*")
+ req.Header.Set("sec-fetch-mode", "cors")
+ req.Header.Set("User-Agent", "node")
+ req.Header.Set("Accept-Encoding", "br, gzip, deflate")
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Warnf("IDC token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
+ }
+
+ var result CreateTokenResponse
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return nil, err
+ }
+
+ expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second)
+
+ return &KiroTokenData{
+ AccessToken: result.AccessToken,
+ RefreshToken: result.RefreshToken,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ AuthMethod: "idc",
+ Provider: "AWS",
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ StartURL: startURL,
+ Region: region,
+ }, nil
+}
+
+// LoginWithIDC performs the full device code flow for AWS Identity Center (IDC).
+func (c *SSOOIDCClient) LoginWithIDC(ctx context.Context, startURL, region string) (*KiroTokenData, error) {
+ fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
+ fmt.Println("║ Kiro Authentication (AWS Identity Center) ║")
+ fmt.Println("╚══════════════════════════════════════════════════════════╝")
+
+ // Step 1: Register client with the specified region
+ fmt.Println("\nRegistering client...")
+ regResp, err := c.RegisterClientWithRegion(ctx, region)
+ if err != nil {
+ return nil, fmt.Errorf("failed to register client: %w", err)
+ }
+ log.Debugf("Client registered: %s", regResp.ClientID)
+
+ // Step 2: Start device authorization with IDC start URL
+ fmt.Println("Starting device authorization...")
+ authResp, err := c.StartDeviceAuthorizationWithIDC(ctx, regResp.ClientID, regResp.ClientSecret, startURL, region)
+ if err != nil {
+ return nil, fmt.Errorf("failed to start device auth: %w", err)
+ }
+
+ // Step 3: Show user the verification URL
+ fmt.Printf("\n")
+ fmt.Println("════════════════════════════════════════════════════════════")
+ fmt.Printf(" Confirm the following code in the browser:\n")
+ fmt.Printf(" Code: %s\n", authResp.UserCode)
+ fmt.Println("════════════════════════════════════════════════════════════")
+ fmt.Printf("\n Open this URL: %s\n\n", authResp.VerificationURIComplete)
+
+ // Set incognito mode based on config
+ if c.cfg != nil {
+ browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
+ if !c.cfg.IncognitoBrowser {
+ log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
+ } else {
+ log.Debug("kiro: using incognito mode for multi-account support")
+ }
+ } else {
+ browser.SetIncognitoMode(true)
+ log.Debug("kiro: using incognito mode for multi-account support (default)")
+ }
+
+ // Open browser
+ if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
+ log.Warnf("Could not open browser automatically: %v", err)
+ fmt.Println(" Please open the URL manually in your browser.")
+ } else {
+ fmt.Println(" (Browser opened automatically)")
+ }
+
+ // Step 4: Poll for token
+ fmt.Println("Waiting for authorization...")
+
+ interval := pollInterval
+ if authResp.Interval > 0 {
+ interval = time.Duration(authResp.Interval) * time.Second
+ }
+
+ deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
+
+ for time.Now().Before(deadline) {
+ select {
+ case <-ctx.Done():
+ browser.CloseBrowser()
+ return nil, ctx.Err()
+ case <-time.After(interval):
+ tokenResp, err := c.CreateTokenWithRegion(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode, region)
+ if err != nil {
+ if errors.Is(err, ErrAuthorizationPending) {
+ fmt.Print(".")
+ continue
+ }
+ if errors.Is(err, ErrSlowDown) {
+ interval += 5 * time.Second
+ continue
+ }
+ browser.CloseBrowser()
+ return nil, fmt.Errorf("token creation failed: %w", err)
+ }
+
+ fmt.Println("\n\n✓ Authorization successful!")
+
+ // Close the browser window
+ if err := browser.CloseBrowser(); err != nil {
+ log.Debugf("Failed to close browser: %v", err)
+ }
+
+ // Step 5: Get profile ARN from CodeWhisperer API
+ fmt.Println("Fetching profile information...")
+ profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
+
+ // Fetch user email
+ email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
+ if email != "" {
+ fmt.Printf(" Logged in as: %s\n", email)
+ }
+
+ expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
+
+ return &KiroTokenData{
+ AccessToken: tokenResp.AccessToken,
+ RefreshToken: tokenResp.RefreshToken,
+ ProfileArn: profileArn,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ AuthMethod: "idc",
+ Provider: "AWS",
+ ClientID: regResp.ClientID,
+ ClientSecret: regResp.ClientSecret,
+ Email: email,
+ StartURL: startURL,
+ Region: region,
+ }, nil
+ }
+ }
+
+ // Close browser on timeout
+ if err := browser.CloseBrowser(); err != nil {
+ log.Debugf("Failed to close browser on timeout: %v", err)
+ }
+ return nil, fmt.Errorf("authorization timed out")
+}
+
+// LoginWithMethodSelection prompts the user to select between Builder ID and IDC, then performs the login.
+func (c *SSOOIDCClient) LoginWithMethodSelection(ctx context.Context) (*KiroTokenData, error) {
+ fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
+ fmt.Println("║ Kiro Authentication (AWS) ║")
+ fmt.Println("╚══════════════════════════════════════════════════════════╝")
+
+ // Prompt for login method
+ options := []string{
+ "Use with Builder ID (personal AWS account)",
+ "Use with IDC Account (organization SSO)",
+ }
+ selection := promptSelect("\n? Select login method:", options)
+
+ if selection == 0 {
+ // Builder ID flow - use existing implementation
+ return c.LoginWithBuilderID(ctx)
+ }
+
+ // IDC flow - prompt for start URL and region
+ fmt.Println()
+ startURL := promptInput("? Enter Start URL", "")
+ if startURL == "" {
+ return nil, fmt.Errorf("start URL is required for IDC login")
+ }
+
+ region := promptInput("? Enter Region", defaultIDCRegion)
+
+ return c.LoginWithIDC(ctx, startURL, region)
+}
+
+// RegisterClient registers a new OIDC client with AWS.
+func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
+ payload := map[string]interface{}{
+ "clientName": "Kiro IDE",
+ "clientType": "public",
+ "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
+ "grantTypes": []string{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"},
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body)))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", kiroUserAgent)
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
+ }
+
+ var result RegisterClientResponse
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return nil, err
+ }
+
+ return &result, nil
+}
+
+// StartDeviceAuthorization starts the device authorization flow.
+func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID, clientSecret string) (*StartDeviceAuthResponse, error) {
+ payload := map[string]string{
+ "clientId": clientID,
+ "clientSecret": clientSecret,
+ "startUrl": builderIDStartURL,
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/device_authorization", strings.NewReader(string(body)))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", kiroUserAgent)
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode)
+ }
+
+ var result StartDeviceAuthResponse
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return nil, err
+ }
+
+ return &result, nil
+}
+
+// CreateToken polls for the access token after user authorization.
+func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, deviceCode string) (*CreateTokenResponse, error) {
+ payload := map[string]string{
+ "clientId": clientID,
+ "clientSecret": clientSecret,
+ "deviceCode": deviceCode,
+ "grantType": "urn:ietf:params:oauth:grant-type:device_code",
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body)))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", kiroUserAgent)
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ // Check for pending authorization
+ if resp.StatusCode == http.StatusBadRequest {
+ var errResp struct {
+ Error string `json:"error"`
+ }
+ if json.Unmarshal(respBody, &errResp) == nil {
+ if errResp.Error == "authorization_pending" {
+ return nil, ErrAuthorizationPending
+ }
+ if errResp.Error == "slow_down" {
+ return nil, ErrSlowDown
+ }
+ }
+ log.Debugf("create token failed: %s", string(respBody))
+ return nil, fmt.Errorf("create token failed")
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
+ }
+
+ var result CreateTokenResponse
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return nil, err
+ }
+
+ return &result, nil
+}
+
+// RefreshToken refreshes an access token using the refresh token.
+// Includes retry logic and improved error handling for better reliability.
+func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) {
+ payload := map[string]string{
+ "clientId": clientID,
+ "clientSecret": clientSecret,
+ "refreshToken": refreshToken,
+ "grantType": "refresh_token",
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body)))
+ if err != nil {
+ return nil, err
+ }
+
+ // Set headers matching Kiro IDE behavior for better compatibility
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Host", "oidc.us-east-1.amazonaws.com")
+ req.Header.Set("x-amz-user-agent", idcAmzUserAgent)
+ req.Header.Set("User-Agent", "node")
+ req.Header.Set("Accept", "*/*")
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Warnf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
+ }
+
+ var result CreateTokenResponse
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return nil, err
+ }
+
+ expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second)
+
+ return &KiroTokenData{
+ AccessToken: result.AccessToken,
+ RefreshToken: result.RefreshToken,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ AuthMethod: "builder-id",
+ Provider: "AWS",
+ ClientID: clientID,
+ ClientSecret: clientSecret,
+ Region: defaultIDCRegion,
+ }, nil
+}
+
+// LoginWithBuilderID performs the full device code flow for AWS Builder ID.
+func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
+ fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
+ fmt.Println("║ Kiro Authentication (AWS Builder ID) ║")
+ fmt.Println("╚══════════════════════════════════════════════════════════╝")
+
+ // Step 1: Register client
+ fmt.Println("\nRegistering client...")
+ regResp, err := c.RegisterClient(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to register client: %w", err)
+ }
+ log.Debugf("Client registered: %s", regResp.ClientID)
+
+ // Step 2: Start device authorization
+ fmt.Println("Starting device authorization...")
+ authResp, err := c.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
+ if err != nil {
+ return nil, fmt.Errorf("failed to start device auth: %w", err)
+ }
+
+ // Step 3: Show user the verification URL
+ fmt.Printf("\n")
+ fmt.Println("════════════════════════════════════════════════════════════")
+ fmt.Printf(" Open this URL in your browser:\n")
+ fmt.Printf(" %s\n", authResp.VerificationURIComplete)
+ fmt.Println("════════════════════════════════════════════════════════════")
+ fmt.Printf("\n Or go to: %s\n", authResp.VerificationURI)
+ fmt.Printf(" And enter code: %s\n\n", authResp.UserCode)
+
+ // Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
+ // Incognito mode enables multi-account support by bypassing cached sessions
+ if c.cfg != nil {
+ browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
+ if !c.cfg.IncognitoBrowser {
+ log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
+ } else {
+ log.Debug("kiro: using incognito mode for multi-account support")
+ }
+ } else {
+ browser.SetIncognitoMode(true) // Default to incognito if no config
+ log.Debug("kiro: using incognito mode for multi-account support (default)")
+ }
+
+ // Open browser using cross-platform browser package
+ if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
+ log.Warnf("Could not open browser automatically: %v", err)
+ fmt.Println(" Please open the URL manually in your browser.")
+ } else {
+ fmt.Println(" (Browser opened automatically)")
+ }
+
+ // Step 4: Poll for token
+ fmt.Println("Waiting for authorization...")
+
+ interval := pollInterval
+ if authResp.Interval > 0 {
+ interval = time.Duration(authResp.Interval) * time.Second
+ }
+
+ deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
+
+ for time.Now().Before(deadline) {
+ select {
+ case <-ctx.Done():
+ browser.CloseBrowser() // Cleanup on cancel
+ return nil, ctx.Err()
+ case <-time.After(interval):
+ tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
+ if err != nil {
+ if errors.Is(err, ErrAuthorizationPending) {
+ fmt.Print(".")
+ continue
+ }
+ if errors.Is(err, ErrSlowDown) {
+ interval += 5 * time.Second
+ continue
+ }
+ // Close browser on error before returning
+ browser.CloseBrowser()
+ return nil, fmt.Errorf("token creation failed: %w", err)
+ }
+
+ fmt.Println("\n\n✓ Authorization successful!")
+
+ // Close the browser window
+ if err := browser.CloseBrowser(); err != nil {
+ log.Debugf("Failed to close browser: %v", err)
+ }
+
+ // Step 5: Get profile ARN from CodeWhisperer API
+ fmt.Println("Fetching profile information...")
+ profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
+
+ // Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
+ email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
+ if email != "" {
+ fmt.Printf(" Logged in as: %s\n", email)
+ }
+
+ expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
+
+ return &KiroTokenData{
+ AccessToken: tokenResp.AccessToken,
+ RefreshToken: tokenResp.RefreshToken,
+ ProfileArn: profileArn,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ AuthMethod: "builder-id",
+ Provider: "AWS",
+ ClientID: regResp.ClientID,
+ ClientSecret: regResp.ClientSecret,
+ Email: email,
+ Region: defaultIDCRegion,
+ }, nil
+ }
+ }
+
+ // Close browser on timeout for better UX
+ if err := browser.CloseBrowser(); err != nil {
+ log.Debugf("Failed to close browser on timeout: %v", err)
+ }
+ return nil, fmt.Errorf("authorization timed out")
+ }
+
+// FetchUserEmail retrieves the user's email from AWS SSO OIDC userinfo endpoint.
+// Falls back to JWT parsing if userinfo fails.
+func (c *SSOOIDCClient) FetchUserEmail(ctx context.Context, accessToken string) string {
+ // Method 1: Try userinfo endpoint (standard OIDC)
+ email := c.tryUserInfoEndpoint(ctx, accessToken)
+ if email != "" {
+ return email
+ }
+
+ // Method 2: Fallback to JWT parsing
+ return ExtractEmailFromJWT(accessToken)
+}
+
+// tryUserInfoEndpoint attempts to get user info from AWS SSO OIDC userinfo endpoint.
+func (c *SSOOIDCClient) tryUserInfoEndpoint(ctx context.Context, accessToken string) string {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, ssoOIDCEndpoint+"/userinfo", nil)
+ if err != nil {
+ return ""
+ }
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ log.Debugf("userinfo request failed: %v", err)
+ return ""
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ respBody, _ := io.ReadAll(resp.Body)
+ log.Debugf("userinfo endpoint returned status %d: %s", resp.StatusCode, string(respBody))
+ return ""
+ }
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return ""
+ }
+
+ log.Debugf("userinfo response: %s", string(respBody))
+
+ var userInfo struct {
+ Email string `json:"email"`
+ Sub string `json:"sub"`
+ PreferredUsername string `json:"preferred_username"`
+ Name string `json:"name"`
+ }
+
+ if err := json.Unmarshal(respBody, &userInfo); err != nil {
+ return ""
+ }
+
+ if userInfo.Email != "" {
+ return userInfo.Email
+ }
+ if userInfo.PreferredUsername != "" && strings.Contains(userInfo.PreferredUsername, "@") {
+ return userInfo.PreferredUsername
+ }
+ return ""
+}
+
+// fetchProfileArn retrieves the profile ARN from CodeWhisperer API.
+// This is needed for file naming since AWS SSO OIDC doesn't return profile info.
+func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string {
+ // Try ListProfiles API first
+ profileArn := c.tryListProfiles(ctx, accessToken)
+ if profileArn != "" {
+ return profileArn
+ }
+
+ // Fallback: Try ListAvailableCustomizations
+ return c.tryListCustomizations(ctx, accessToken)
+}
+
+func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string {
+ payload := map[string]interface{}{
+ "origin": "AI_EDITOR",
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return ""
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
+ if err != nil {
+ return ""
+ }
+
+ req.Header.Set("Content-Type", "application/x-amz-json-1.0")
+ req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListProfiles")
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return ""
+ }
+ defer resp.Body.Close()
+
+ respBody, _ := io.ReadAll(resp.Body)
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
+ return ""
+ }
+
+ log.Debugf("ListProfiles response: %s", string(respBody))
+
+ var result struct {
+ Profiles []struct {
+ Arn string `json:"arn"`
+ } `json:"profiles"`
+ ProfileArn string `json:"profileArn"`
+ }
+
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return ""
+ }
+
+ if result.ProfileArn != "" {
+ return result.ProfileArn
+ }
+
+ if len(result.Profiles) > 0 {
+ return result.Profiles[0].Arn
+ }
+
+ return ""
+}
+
+func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string {
+ payload := map[string]interface{}{
+ "origin": "AI_EDITOR",
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return ""
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
+ if err != nil {
+ return ""
+ }
+
+ req.Header.Set("Content-Type", "application/x-amz-json-1.0")
+ req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations")
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return ""
+ }
+ defer resp.Body.Close()
+
+ respBody, _ := io.ReadAll(resp.Body)
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody))
+ return ""
+ }
+
+ log.Debugf("ListAvailableCustomizations response: %s", string(respBody))
+
+ var result struct {
+ Customizations []struct {
+ Arn string `json:"arn"`
+ } `json:"customizations"`
+ ProfileArn string `json:"profileArn"`
+ }
+
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return ""
+ }
+
+ if result.ProfileArn != "" {
+ return result.ProfileArn
+ }
+
+ if len(result.Customizations) > 0 {
+ return result.Customizations[0].Arn
+ }
+
+ return ""
+}
+
+// RegisterClientForAuthCode registers a new OIDC client for authorization code flow.
+func (c *SSOOIDCClient) RegisterClientForAuthCode(ctx context.Context, redirectURI string) (*RegisterClientResponse, error) {
+ payload := map[string]interface{}{
+ "clientName": "Kiro IDE",
+ "clientType": "public",
+ "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations", "codewhisperer:transformations", "codewhisperer:taskassist"},
+ "grantTypes": []string{"authorization_code", "refresh_token"},
+ "redirectUris": []string{redirectURI},
+ "issuerUrl": builderIDStartURL,
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body)))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", kiroUserAgent)
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("register client for auth code failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
+ }
+
+ var result RegisterClientResponse
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return nil, err
+ }
+
+ return &result, nil
+}
+
+// AuthCodeCallbackResult contains the result from authorization code callback.
+type AuthCodeCallbackResult struct {
+ Code string
+ State string
+ Error string
+}
+
+// startAuthCodeCallbackServer starts a local HTTP server to receive the authorization code callback.
+func (c *SSOOIDCClient) startAuthCodeCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthCodeCallbackResult, error) {
+ // Try to find an available port
+ listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", authCodeCallbackPort))
+ if err != nil {
+ // Try with dynamic port
+ log.Warnf("sso oidc: default port %d is busy, falling back to dynamic port", authCodeCallbackPort)
+ listener, err = net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ return "", nil, fmt.Errorf("failed to start callback server: %w", err)
+ }
+ }
+
+ port := listener.Addr().(*net.TCPAddr).Port
+ redirectURI := fmt.Sprintf("http://127.0.0.1:%d%s", port, authCodeCallbackPath)
+ resultChan := make(chan AuthCodeCallbackResult, 1)
+
+ server := &http.Server{
+ ReadHeaderTimeout: 10 * time.Second,
+ }
+
+ mux := http.NewServeMux()
+ mux.HandleFunc(authCodeCallbackPath, func(w http.ResponseWriter, r *http.Request) {
+ code := r.URL.Query().Get("code")
+ state := r.URL.Query().Get("state")
+ errParam := r.URL.Query().Get("error")
+
+ // Send response to browser
+ w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ if errParam != "" {
+ w.WriteHeader(http.StatusBadRequest)
+ fmt.Fprintf(w, `
+Login Failed
+Login Failed
Error: %s
You can close this window.
`, html.EscapeString(errParam))
+ resultChan <- AuthCodeCallbackResult{Error: errParam}
+ return
+ }
+
+ if state != expectedState {
+ w.WriteHeader(http.StatusBadRequest)
+ fmt.Fprint(w, `
+Login Failed
+Login Failed
Invalid state parameter
You can close this window.
`)
+ resultChan <- AuthCodeCallbackResult{Error: "state mismatch"}
+ return
+ }
+
+ fmt.Fprint(w, `
+Login Successful
+Login Successful!
You can close this window and return to the terminal.
+`)
+ resultChan <- AuthCodeCallbackResult{Code: code, State: state}
+ })
+
+ server.Handler = mux
+
+ go func() {
+ if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
+ log.Debugf("auth code callback server error: %v", err)
+ }
+ }()
+
+ go func() {
+ select {
+ case <-ctx.Done():
+ case <-time.After(10 * time.Minute):
+ case <-resultChan:
+ }
+ _ = server.Shutdown(context.Background())
+ }()
+
+ return redirectURI, resultChan, nil
+}
+
+// generatePKCEForAuthCode generates PKCE code verifier and challenge for authorization code flow.
+func generatePKCEForAuthCode() (verifier, challenge string, err error) {
+ b := make([]byte, 32)
+ if _, err := rand.Read(b); err != nil {
+ return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
+ }
+ verifier = base64.RawURLEncoding.EncodeToString(b)
+ h := sha256.Sum256([]byte(verifier))
+ challenge = base64.RawURLEncoding.EncodeToString(h[:])
+ return verifier, challenge, nil
+}
+
+// generateStateForAuthCode generates a random state parameter.
+func generateStateForAuthCode() (string, error) {
+ b := make([]byte, 16)
+ if _, err := rand.Read(b); err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// CreateTokenWithAuthCode exchanges authorization code for tokens.
+func (c *SSOOIDCClient) CreateTokenWithAuthCode(ctx context.Context, clientID, clientSecret, code, codeVerifier, redirectURI string) (*CreateTokenResponse, error) {
+ payload := map[string]string{
+ "clientId": clientID,
+ "clientSecret": clientSecret,
+ "code": code,
+ "codeVerifier": codeVerifier,
+ "redirectUri": redirectURI,
+ "grantType": "authorization_code",
+ }
+
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body)))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", kiroUserAgent)
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ respBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Debugf("create token with auth code failed (status %d): %s", resp.StatusCode, string(respBody))
+ return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
+ }
+
+ var result CreateTokenResponse
+ if err := json.Unmarshal(respBody, &result); err != nil {
+ return nil, err
+ }
+
+ return &result, nil
+}
+
+// LoginWithBuilderIDAuthCode performs the authorization code flow for AWS Builder ID.
+// This provides a better UX than device code flow as it uses automatic browser callback.
+func (c *SSOOIDCClient) LoginWithBuilderIDAuthCode(ctx context.Context) (*KiroTokenData, error) {
+ fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
+ fmt.Println("║ Kiro Authentication (AWS Builder ID - Auth Code) ║")
+ fmt.Println("╚══════════════════════════════════════════════════════════╝")
+
+ // Step 1: Generate PKCE and state
+ codeVerifier, codeChallenge, err := generatePKCEForAuthCode()
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate PKCE: %w", err)
+ }
+
+ state, err := generateStateForAuthCode()
+ if err != nil {
+ return nil, fmt.Errorf("failed to generate state: %w", err)
+ }
+
+ // Step 2: Start callback server
+ fmt.Println("\nStarting callback server...")
+ redirectURI, resultChan, err := c.startAuthCodeCallbackServer(ctx, state)
+ if err != nil {
+ return nil, fmt.Errorf("failed to start callback server: %w", err)
+ }
+ log.Debugf("Callback server started, redirect URI: %s", redirectURI)
+
+ // Step 3: Register client with auth code grant type
+ fmt.Println("Registering client...")
+ regResp, err := c.RegisterClientForAuthCode(ctx, redirectURI)
+ if err != nil {
+ return nil, fmt.Errorf("failed to register client: %w", err)
+ }
+ log.Debugf("Client registered: %s", regResp.ClientID)
+
+ // Step 4: Build authorization URL
+ scopes := "codewhisperer:completions,codewhisperer:analysis,codewhisperer:conversations"
+ authURL := fmt.Sprintf("%s/authorize?response_type=code&client_id=%s&redirect_uri=%s&scopes=%s&state=%s&code_challenge=%s&code_challenge_method=S256",
+ ssoOIDCEndpoint,
+ regResp.ClientID,
+ redirectURI,
+ scopes,
+ state,
+ codeChallenge,
+ )
+
+ // Step 5: Open browser
+ fmt.Println("\n════════════════════════════════════════════════════════════")
+ fmt.Println(" Opening browser for authentication...")
+ fmt.Println("════════════════════════════════════════════════════════════")
+ fmt.Printf("\n URL: %s\n\n", authURL)
+
+ // Set incognito mode
+ if c.cfg != nil {
+ browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
+ } else {
+ browser.SetIncognitoMode(true)
+ }
+
+ if err := browser.OpenURL(authURL); err != nil {
+ log.Warnf("Could not open browser automatically: %v", err)
+ fmt.Println(" ⚠ Could not open browser automatically.")
+ fmt.Println(" Please open the URL above in your browser manually.")
+ } else {
+ fmt.Println(" (Browser opened automatically)")
+ }
+
+ fmt.Println("\n Waiting for authorization callback...")
+
+ // Step 6: Wait for callback
+ select {
+ case <-ctx.Done():
+ browser.CloseBrowser()
+ return nil, ctx.Err()
+ case <-time.After(10 * time.Minute):
+ browser.CloseBrowser()
+ return nil, fmt.Errorf("authorization timed out")
+ case result := <-resultChan:
+ if result.Error != "" {
+ browser.CloseBrowser()
+ return nil, fmt.Errorf("authorization failed: %s", result.Error)
+ }
+
+ fmt.Println("\n✓ Authorization received!")
+
+ // Close browser
+ if err := browser.CloseBrowser(); err != nil {
+ log.Debugf("Failed to close browser: %v", err)
+ }
+
+ // Step 7: Exchange code for tokens
+ fmt.Println("Exchanging code for tokens...")
+ tokenResp, err := c.CreateTokenWithAuthCode(ctx, regResp.ClientID, regResp.ClientSecret, result.Code, codeVerifier, redirectURI)
+ if err != nil {
+ return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
+ }
+
+ fmt.Println("\n✓ Authentication successful!")
+
+ // Step 8: Get profile ARN
+ fmt.Println("Fetching profile information...")
+ profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
+
+ // Fetch user email (tries CodeWhisperer API first, then userinfo endpoint, then JWT parsing)
+ email := FetchUserEmailWithFallback(ctx, c.cfg, tokenResp.AccessToken)
+ if email != "" {
+ fmt.Printf(" Logged in as: %s\n", email)
+ }
+
+ expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
+
+ return &KiroTokenData{
+ AccessToken: tokenResp.AccessToken,
+ RefreshToken: tokenResp.RefreshToken,
+ ProfileArn: profileArn,
+ ExpiresAt: expiresAt.Format(time.RFC3339),
+ AuthMethod: "builder-id",
+ Provider: "AWS",
+ ClientID: regResp.ClientID,
+ ClientSecret: regResp.ClientSecret,
+ Email: email,
+ Region: defaultIDCRegion,
+ }, nil
+ }
+}
diff --git a/internal/auth/kiro/token.go b/internal/auth/kiro/token.go
new file mode 100644
index 00000000..0484a2dc
--- /dev/null
+++ b/internal/auth/kiro/token.go
@@ -0,0 +1,89 @@
+package kiro
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+)
+
+// KiroTokenStorage holds the persistent token data for Kiro authentication.
+type KiroTokenStorage struct {
+ // Type is the provider type for management UI recognition (must be "kiro")
+ Type string `json:"type"`
+ // AccessToken is the OAuth2 access token for API access
+ AccessToken string `json:"access_token"`
+ // RefreshToken is used to obtain new access tokens
+ RefreshToken string `json:"refresh_token"`
+ // ProfileArn is the AWS CodeWhisperer profile ARN
+ ProfileArn string `json:"profile_arn"`
+ // ExpiresAt is the timestamp when the token expires
+ ExpiresAt string `json:"expires_at"`
+ // AuthMethod indicates the authentication method used
+ AuthMethod string `json:"auth_method"`
+ // Provider indicates the OAuth provider
+ Provider string `json:"provider"`
+ // LastRefresh is the timestamp of the last token refresh
+ LastRefresh string `json:"last_refresh"`
+ // ClientID is the OAuth client ID (required for token refresh)
+ ClientID string `json:"client_id,omitempty"`
+ // ClientSecret is the OAuth client secret (required for token refresh)
+ ClientSecret string `json:"client_secret,omitempty"`
+ // Region is the AWS region
+ Region string `json:"region,omitempty"`
+ // StartURL is the AWS Identity Center start URL (for IDC auth)
+ StartURL string `json:"start_url,omitempty"`
+ // Email is the user's email address
+ Email string `json:"email,omitempty"`
+}
+
+// SaveTokenToFile persists the token storage to the specified file path.
+func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error {
+ dir := filepath.Dir(authFilePath)
+ if err := os.MkdirAll(dir, 0700); err != nil {
+ return fmt.Errorf("failed to create directory: %w", err)
+ }
+
+ data, err := json.MarshalIndent(s, "", " ")
+ if err != nil {
+ return fmt.Errorf("failed to marshal token storage: %w", err)
+ }
+
+ if err := os.WriteFile(authFilePath, data, 0600); err != nil {
+ return fmt.Errorf("failed to write token file: %w", err)
+ }
+
+ return nil
+}
+
+// LoadFromFile loads token storage from the specified file path.
+func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) {
+ data, err := os.ReadFile(authFilePath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read token file: %w", err)
+ }
+
+ var storage KiroTokenStorage
+ if err := json.Unmarshal(data, &storage); err != nil {
+ return nil, fmt.Errorf("failed to parse token file: %w", err)
+ }
+
+ return &storage, nil
+}
+
+// ToTokenData converts storage to KiroTokenData for API use.
+func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
+ return &KiroTokenData{
+ AccessToken: s.AccessToken,
+ RefreshToken: s.RefreshToken,
+ ProfileArn: s.ProfileArn,
+ ExpiresAt: s.ExpiresAt,
+ AuthMethod: s.AuthMethod,
+ Provider: s.Provider,
+ ClientID: s.ClientID,
+ ClientSecret: s.ClientSecret,
+ Region: s.Region,
+ StartURL: s.StartURL,
+ Email: s.Email,
+ }
+}
diff --git a/internal/auth/kiro/token_repository.go b/internal/auth/kiro/token_repository.go
new file mode 100644
index 00000000..815f1827
--- /dev/null
+++ b/internal/auth/kiro/token_repository.go
@@ -0,0 +1,274 @@
+package kiro
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io/fs"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// FileTokenRepository 实现 TokenRepository 接口,基于文件系统存储
+type FileTokenRepository struct {
+ mu sync.RWMutex
+ baseDir string
+}
+
+// NewFileTokenRepository 创建一个新的文件 token 存储库
+func NewFileTokenRepository(baseDir string) *FileTokenRepository {
+ return &FileTokenRepository{
+ baseDir: baseDir,
+ }
+}
+
+// SetBaseDir 设置基础目录
+func (r *FileTokenRepository) SetBaseDir(dir string) {
+ r.mu.Lock()
+ r.baseDir = strings.TrimSpace(dir)
+ r.mu.Unlock()
+}
+
+// FindOldestUnverified 查找需要刷新的 token(按最后验证时间排序)
+func (r *FileTokenRepository) FindOldestUnverified(limit int) []*Token {
+ r.mu.RLock()
+ baseDir := r.baseDir
+ r.mu.RUnlock()
+
+ if baseDir == "" {
+ log.Debug("token repository: base directory not configured")
+ return nil
+ }
+
+ var tokens []*Token
+
+ err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
+ if walkErr != nil {
+ return nil // 忽略错误,继续遍历
+ }
+ if d.IsDir() {
+ return nil
+ }
+ if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
+ return nil
+ }
+
+ // 只处理 kiro 相关的 token 文件
+ if !strings.HasPrefix(d.Name(), "kiro-") {
+ return nil
+ }
+
+ token, err := r.readTokenFile(path)
+ if err != nil {
+ log.Debugf("token repository: failed to read token file %s: %v", path, err)
+ return nil
+ }
+
+ if token != nil && token.RefreshToken != "" {
+ // 检查 token 是否需要刷新(过期前 5 分钟)
+ if token.ExpiresAt.IsZero() || time.Until(token.ExpiresAt) < 5*time.Minute {
+ tokens = append(tokens, token)
+ }
+ }
+
+ return nil
+ })
+
+ if err != nil {
+ log.Warnf("token repository: error walking directory: %v", err)
+ }
+
+ // 按最后验证时间排序(最旧的优先)
+ sort.Slice(tokens, func(i, j int) bool {
+ return tokens[i].LastVerified.Before(tokens[j].LastVerified)
+ })
+
+ // 限制返回数量
+ if limit > 0 && len(tokens) > limit {
+ tokens = tokens[:limit]
+ }
+
+ return tokens
+}
+
+// UpdateToken 更新 token 并持久化到文件
+func (r *FileTokenRepository) UpdateToken(token *Token) error {
+ if token == nil {
+ return fmt.Errorf("token repository: token is nil")
+ }
+
+ r.mu.RLock()
+ baseDir := r.baseDir
+ r.mu.RUnlock()
+
+ if baseDir == "" {
+ return fmt.Errorf("token repository: base directory not configured")
+ }
+
+ // 构建文件路径
+ filePath := filepath.Join(baseDir, token.ID)
+ if !strings.HasSuffix(filePath, ".json") {
+ filePath += ".json"
+ }
+
+ // 读取现有文件内容
+ existingData := make(map[string]any)
+ if data, err := os.ReadFile(filePath); err == nil {
+ _ = json.Unmarshal(data, &existingData)
+ }
+
+ // 更新字段
+ existingData["access_token"] = token.AccessToken
+ existingData["refresh_token"] = token.RefreshToken
+ existingData["last_refresh"] = time.Now().Format(time.RFC3339)
+
+ if !token.ExpiresAt.IsZero() {
+ existingData["expires_at"] = token.ExpiresAt.Format(time.RFC3339)
+ }
+
+ // 保持原有的关键字段
+ if token.ClientID != "" {
+ existingData["client_id"] = token.ClientID
+ }
+ if token.ClientSecret != "" {
+ existingData["client_secret"] = token.ClientSecret
+ }
+ if token.AuthMethod != "" {
+ existingData["auth_method"] = token.AuthMethod
+ }
+ if token.Region != "" {
+ existingData["region"] = token.Region
+ }
+ if token.StartURL != "" {
+ existingData["start_url"] = token.StartURL
+ }
+
+ // 序列化并写入文件
+ raw, err := json.MarshalIndent(existingData, "", " ")
+ if err != nil {
+ return fmt.Errorf("token repository: marshal failed: %w", err)
+ }
+
+ // 原子写入:先写入临时文件,再重命名
+ tmpPath := filePath + ".tmp"
+ if err := os.WriteFile(tmpPath, raw, 0o600); err != nil {
+ return fmt.Errorf("token repository: write temp file failed: %w", err)
+ }
+ if err := os.Rename(tmpPath, filePath); err != nil {
+ _ = os.Remove(tmpPath)
+ return fmt.Errorf("token repository: rename failed: %w", err)
+ }
+
+ log.Debugf("token repository: updated token %s", token.ID)
+ return nil
+}
+
+// readTokenFile 从文件读取 token
+func (r *FileTokenRepository) readTokenFile(path string) (*Token, error) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+
+ var metadata map[string]any
+ if err := json.Unmarshal(data, &metadata); err != nil {
+ return nil, err
+ }
+
+ // 检查是否是 kiro token
+ tokenType, _ := metadata["type"].(string)
+ if tokenType != "kiro" {
+ return nil, nil
+ }
+
+ // 检查 auth_method (case-insensitive comparison to handle "IdC", "IDC", "idc", etc.)
+ authMethod, _ := metadata["auth_method"].(string)
+ authMethod = strings.ToLower(authMethod)
+ if authMethod != "idc" && authMethod != "builder-id" {
+ return nil, nil // 只处理 IDC 和 Builder ID token
+ }
+
+ token := &Token{
+ ID: filepath.Base(path),
+ AuthMethod: authMethod,
+ }
+
+ // 解析各字段
+ if v, ok := metadata["access_token"].(string); ok {
+ token.AccessToken = v
+ }
+ if v, ok := metadata["refresh_token"].(string); ok {
+ token.RefreshToken = v
+ }
+ if v, ok := metadata["client_id"].(string); ok {
+ token.ClientID = v
+ }
+ if v, ok := metadata["client_secret"].(string); ok {
+ token.ClientSecret = v
+ }
+ if v, ok := metadata["region"].(string); ok {
+ token.Region = v
+ }
+ if v, ok := metadata["start_url"].(string); ok {
+ token.StartURL = v
+ }
+ if v, ok := metadata["provider"].(string); ok {
+ token.Provider = v
+ }
+
+ // 解析时间字段
+ if v, ok := metadata["expires_at"].(string); ok {
+ if t, err := time.Parse(time.RFC3339, v); err == nil {
+ token.ExpiresAt = t
+ }
+ }
+ if v, ok := metadata["last_refresh"].(string); ok {
+ if t, err := time.Parse(time.RFC3339, v); err == nil {
+ token.LastVerified = t
+ }
+ }
+
+ return token, nil
+}
+
+// ListKiroTokens 列出所有 Kiro token(用于调试)
+func (r *FileTokenRepository) ListKiroTokens(ctx context.Context) ([]*Token, error) {
+ r.mu.RLock()
+ baseDir := r.baseDir
+ r.mu.RUnlock()
+
+ if baseDir == "" {
+ return nil, fmt.Errorf("token repository: base directory not configured")
+ }
+
+ var tokens []*Token
+
+ err := filepath.WalkDir(baseDir, func(path string, d fs.DirEntry, walkErr error) error {
+ if walkErr != nil {
+ return nil
+ }
+ if d.IsDir() {
+ return nil
+ }
+ if !strings.HasPrefix(d.Name(), "kiro-") || !strings.HasSuffix(d.Name(), ".json") {
+ return nil
+ }
+
+ token, err := r.readTokenFile(path)
+ if err != nil {
+ return nil
+ }
+ if token != nil {
+ tokens = append(tokens, token)
+ }
+ return nil
+ })
+
+ return tokens, err
+}
diff --git a/internal/auth/kiro/usage_checker.go b/internal/auth/kiro/usage_checker.go
new file mode 100644
index 00000000..94870214
--- /dev/null
+++ b/internal/auth/kiro/usage_checker.go
@@ -0,0 +1,243 @@
+// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
+// This file implements usage quota checking and monitoring.
+package kiro
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+)
+
+// UsageQuotaResponse represents the API response structure for usage quota checking.
+type UsageQuotaResponse struct {
+ UsageBreakdownList []UsageBreakdownExtended `json:"usageBreakdownList"`
+ SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
+ NextDateReset float64 `json:"nextDateReset,omitempty"`
+}
+
+// UsageBreakdownExtended represents detailed usage information for quota checking.
+// Note: UsageBreakdown is already defined in codewhisperer_client.go
+type UsageBreakdownExtended struct {
+ ResourceType string `json:"resourceType"`
+ UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
+ CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
+ FreeTrialInfo *FreeTrialInfoExtended `json:"freeTrialInfo,omitempty"`
+}
+
+// FreeTrialInfoExtended represents free trial usage information.
+type FreeTrialInfoExtended struct {
+ FreeTrialStatus string `json:"freeTrialStatus"`
+ UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
+ CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
+}
+
+// QuotaStatus represents the quota status for a token.
+type QuotaStatus struct {
+ TotalLimit float64
+ CurrentUsage float64
+ RemainingQuota float64
+ IsExhausted bool
+ ResourceType string
+ NextReset time.Time
+}
+
+// UsageChecker provides methods for checking token quota usage.
+type UsageChecker struct {
+ httpClient *http.Client
+ endpoint string
+}
+
+// NewUsageChecker creates a new UsageChecker instance.
+func NewUsageChecker(cfg *config.Config) *UsageChecker {
+ return &UsageChecker{
+ httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
+ endpoint: awsKiroEndpoint,
+ }
+}
+
+// NewUsageCheckerWithClient creates a UsageChecker with a custom HTTP client.
+func NewUsageCheckerWithClient(client *http.Client) *UsageChecker {
+ return &UsageChecker{
+ httpClient: client,
+ endpoint: awsKiroEndpoint,
+ }
+}
+
+// CheckUsage retrieves usage limits for the given token.
+func (c *UsageChecker) CheckUsage(ctx context.Context, tokenData *KiroTokenData) (*UsageQuotaResponse, error) {
+ if tokenData == nil {
+ return nil, fmt.Errorf("token data is nil")
+ }
+
+ if tokenData.AccessToken == "" {
+ return nil, fmt.Errorf("access token is empty")
+ }
+
+ payload := map[string]interface{}{
+ "origin": "AI_EDITOR",
+ "profileArn": tokenData.ProfileArn,
+ "resourceType": "AGENTIC_REQUEST",
+ }
+
+ jsonBody, err := json.Marshal(payload)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint, strings.NewReader(string(jsonBody)))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/x-amz-json-1.0")
+ req.Header.Set("x-amz-target", targetGetUsage)
+ req.Header.Set("Authorization", "Bearer "+tokenData.AccessToken)
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
+ }
+
+ var result UsageQuotaResponse
+ if err := json.Unmarshal(body, &result); err != nil {
+ return nil, fmt.Errorf("failed to parse usage response: %w", err)
+ }
+
+ return &result, nil
+}
+
+// CheckUsageByAccessToken retrieves usage limits using an access token and profile ARN directly.
+func (c *UsageChecker) CheckUsageByAccessToken(ctx context.Context, accessToken, profileArn string) (*UsageQuotaResponse, error) {
+ tokenData := &KiroTokenData{
+ AccessToken: accessToken,
+ ProfileArn: profileArn,
+ }
+ return c.CheckUsage(ctx, tokenData)
+}
+
+// GetRemainingQuota calculates the remaining quota from usage limits.
+func GetRemainingQuota(usage *UsageQuotaResponse) float64 {
+ if usage == nil || len(usage.UsageBreakdownList) == 0 {
+ return 0
+ }
+
+ var totalRemaining float64
+ for _, breakdown := range usage.UsageBreakdownList {
+ remaining := breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
+ if remaining > 0 {
+ totalRemaining += remaining
+ }
+
+ if breakdown.FreeTrialInfo != nil {
+ freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
+ if freeRemaining > 0 {
+ totalRemaining += freeRemaining
+ }
+ }
+ }
+
+ return totalRemaining
+}
+
+// IsQuotaExhausted checks if the quota is exhausted based on usage limits.
+func IsQuotaExhausted(usage *UsageQuotaResponse) bool {
+ if usage == nil || len(usage.UsageBreakdownList) == 0 {
+ return true
+ }
+
+ for _, breakdown := range usage.UsageBreakdownList {
+ if breakdown.CurrentUsageWithPrecision < breakdown.UsageLimitWithPrecision {
+ return false
+ }
+
+ if breakdown.FreeTrialInfo != nil {
+ if breakdown.FreeTrialInfo.CurrentUsageWithPrecision < breakdown.FreeTrialInfo.UsageLimitWithPrecision {
+ return false
+ }
+ }
+ }
+
+ return true
+}
+
+// GetQuotaStatus retrieves a comprehensive quota status for a token.
+func (c *UsageChecker) GetQuotaStatus(ctx context.Context, tokenData *KiroTokenData) (*QuotaStatus, error) {
+ usage, err := c.CheckUsage(ctx, tokenData)
+ if err != nil {
+ return nil, err
+ }
+
+ status := &QuotaStatus{
+ IsExhausted: IsQuotaExhausted(usage),
+ }
+
+ if len(usage.UsageBreakdownList) > 0 {
+ breakdown := usage.UsageBreakdownList[0]
+ status.TotalLimit = breakdown.UsageLimitWithPrecision
+ status.CurrentUsage = breakdown.CurrentUsageWithPrecision
+ status.RemainingQuota = breakdown.UsageLimitWithPrecision - breakdown.CurrentUsageWithPrecision
+ status.ResourceType = breakdown.ResourceType
+
+ if breakdown.FreeTrialInfo != nil {
+ status.TotalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
+ status.CurrentUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
+ freeRemaining := breakdown.FreeTrialInfo.UsageLimitWithPrecision - breakdown.FreeTrialInfo.CurrentUsageWithPrecision
+ if freeRemaining > 0 {
+ status.RemainingQuota += freeRemaining
+ }
+ }
+ }
+
+ if usage.NextDateReset > 0 {
+ status.NextReset = time.Unix(int64(usage.NextDateReset/1000), 0)
+ }
+
+ return status, nil
+}
+
+// CalculateAvailableCount calculates the available request count based on usage limits.
+func CalculateAvailableCount(usage *UsageQuotaResponse) float64 {
+ return GetRemainingQuota(usage)
+}
+
+// GetUsagePercentage calculates the usage percentage.
+func GetUsagePercentage(usage *UsageQuotaResponse) float64 {
+ if usage == nil || len(usage.UsageBreakdownList) == 0 {
+ return 100.0
+ }
+
+ var totalLimit, totalUsage float64
+ for _, breakdown := range usage.UsageBreakdownList {
+ totalLimit += breakdown.UsageLimitWithPrecision
+ totalUsage += breakdown.CurrentUsageWithPrecision
+
+ if breakdown.FreeTrialInfo != nil {
+ totalLimit += breakdown.FreeTrialInfo.UsageLimitWithPrecision
+ totalUsage += breakdown.FreeTrialInfo.CurrentUsageWithPrecision
+ }
+ }
+
+ if totalLimit == 0 {
+ return 100.0
+ }
+
+ return (totalUsage / totalLimit) * 100
+}
diff --git a/internal/browser/browser.go b/internal/browser/browser.go
index b24dc5e1..3a5aeea7 100644
--- a/internal/browser/browser.go
+++ b/internal/browser/browser.go
@@ -6,14 +6,49 @@ import (
"fmt"
"os/exec"
"runtime"
+ "strings"
+ "sync"
+ pkgbrowser "github.com/pkg/browser"
log "github.com/sirupsen/logrus"
- "github.com/skratchdot/open-golang/open"
)
+// incognitoMode controls whether to open URLs in incognito/private mode.
+// This is useful for OAuth flows where you want to use a different account.
+var incognitoMode bool
+
+// lastBrowserProcess stores the last opened browser process for cleanup
+var lastBrowserProcess *exec.Cmd
+var browserMutex sync.Mutex
+
+// SetIncognitoMode enables or disables incognito/private browsing mode.
+func SetIncognitoMode(enabled bool) {
+ incognitoMode = enabled
+}
+
+// IsIncognitoMode returns whether incognito mode is enabled.
+func IsIncognitoMode() bool {
+ return incognitoMode
+}
+
+// CloseBrowser closes the last opened browser process.
+func CloseBrowser() error {
+ browserMutex.Lock()
+ defer browserMutex.Unlock()
+
+ if lastBrowserProcess == nil || lastBrowserProcess.Process == nil {
+ return nil
+ }
+
+ err := lastBrowserProcess.Process.Kill()
+ lastBrowserProcess = nil
+ return err
+}
+
// OpenURL opens the specified URL in the default web browser.
-// It first attempts to use a platform-agnostic library and falls back to
-// platform-specific commands if that fails.
+// It uses the pkg/browser library which provides robust cross-platform support
+// for Windows, macOS, and Linux.
+// If incognito mode is enabled, it will open in a private/incognito window.
//
// Parameters:
// - url: The URL to open.
@@ -21,16 +56,22 @@ import (
// Returns:
// - An error if the URL cannot be opened, otherwise nil.
func OpenURL(url string) error {
- fmt.Printf("Attempting to open URL in browser: %s\n", url)
+ log.Debugf("Opening URL in browser: %s (incognito=%v)", url, incognitoMode)
- // Try using the open-golang library first
- err := open.Run(url)
+ // If incognito mode is enabled, use platform-specific incognito commands
+ if incognitoMode {
+ log.Debug("Using incognito mode")
+ return openURLIncognito(url)
+ }
+
+ // Use pkg/browser for cross-platform support
+ err := pkgbrowser.OpenURL(url)
if err == nil {
- log.Debug("Successfully opened URL using open-golang library")
+ log.Debug("Successfully opened URL using pkg/browser library")
return nil
}
- log.Debugf("open-golang failed: %v, trying platform-specific commands", err)
+ log.Debugf("pkg/browser failed: %v, trying platform-specific commands", err)
// Fallback to platform-specific commands
return openURLPlatformSpecific(url)
@@ -78,18 +119,379 @@ func openURLPlatformSpecific(url string) error {
return nil
}
+// openURLIncognito opens a URL in incognito/private browsing mode.
+// It first tries to detect the default browser and use its incognito flag.
+// Falls back to a chain of known browsers if detection fails.
+//
+// Parameters:
+// - url: The URL to open.
+//
+// Returns:
+// - An error if the URL cannot be opened, otherwise nil.
+func openURLIncognito(url string) error {
+ // First, try to detect and use the default browser
+ if cmd := tryDefaultBrowserIncognito(url); cmd != nil {
+ log.Debugf("Using detected default browser: %s %v", cmd.Path, cmd.Args[1:])
+ if err := cmd.Start(); err == nil {
+ storeBrowserProcess(cmd)
+ log.Debug("Successfully opened URL in default browser's incognito mode")
+ return nil
+ }
+ log.Debugf("Failed to start default browser, trying fallback chain")
+ }
+
+ // Fallback to known browser chain
+ cmd := tryFallbackBrowsersIncognito(url)
+ if cmd == nil {
+ log.Warn("No browser with incognito support found, falling back to normal mode")
+ return openURLPlatformSpecific(url)
+ }
+
+ log.Debugf("Running incognito command: %s %v", cmd.Path, cmd.Args[1:])
+ err := cmd.Start()
+ if err != nil {
+ log.Warnf("Failed to open incognito browser: %v, falling back to normal mode", err)
+ return openURLPlatformSpecific(url)
+ }
+
+ storeBrowserProcess(cmd)
+ log.Debug("Successfully opened URL in incognito/private mode")
+ return nil
+}
+
+// storeBrowserProcess safely stores the browser process for later cleanup.
+func storeBrowserProcess(cmd *exec.Cmd) {
+ browserMutex.Lock()
+ lastBrowserProcess = cmd
+ browserMutex.Unlock()
+}
+
+// tryDefaultBrowserIncognito attempts to detect the default browser and return
+// an exec.Cmd configured with the appropriate incognito flag.
+func tryDefaultBrowserIncognito(url string) *exec.Cmd {
+ switch runtime.GOOS {
+ case "darwin":
+ return tryDefaultBrowserMacOS(url)
+ case "windows":
+ return tryDefaultBrowserWindows(url)
+ case "linux":
+ return tryDefaultBrowserLinux(url)
+ }
+ return nil
+}
+
+// tryDefaultBrowserMacOS detects the default browser on macOS.
+func tryDefaultBrowserMacOS(url string) *exec.Cmd {
+ // Try to get default browser from Launch Services
+ out, err := exec.Command("defaults", "read", "com.apple.LaunchServices/com.apple.launchservices.secure", "LSHandlers").Output()
+ if err != nil {
+ return nil
+ }
+
+ output := string(out)
+ var browserName string
+
+ // Parse the output to find the http/https handler
+ if containsBrowserID(output, "com.google.chrome") {
+ browserName = "chrome"
+ } else if containsBrowserID(output, "org.mozilla.firefox") {
+ browserName = "firefox"
+ } else if containsBrowserID(output, "com.apple.safari") {
+ browserName = "safari"
+ } else if containsBrowserID(output, "com.brave.browser") {
+ browserName = "brave"
+ } else if containsBrowserID(output, "com.microsoft.edgemac") {
+ browserName = "edge"
+ }
+
+ return createMacOSIncognitoCmd(browserName, url)
+}
+
+// containsBrowserID checks if the LaunchServices output contains a browser ID.
+func containsBrowserID(output, bundleID string) bool {
+ return strings.Contains(output, bundleID)
+}
+
+// createMacOSIncognitoCmd creates the appropriate incognito command for macOS browsers.
+func createMacOSIncognitoCmd(browserName, url string) *exec.Cmd {
+ switch browserName {
+ case "chrome":
+ // Try direct path first
+ chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
+ if _, err := exec.LookPath(chromePath); err == nil {
+ return exec.Command(chromePath, "--incognito", url)
+ }
+ return exec.Command("open", "-na", "Google Chrome", "--args", "--incognito", url)
+ case "firefox":
+ return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url)
+ case "safari":
+ // Safari doesn't have CLI incognito, try AppleScript
+ return tryAppleScriptSafariPrivate(url)
+ case "brave":
+ return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url)
+ case "edge":
+ return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url)
+ }
+ return nil
+}
+
+// tryAppleScriptSafariPrivate attempts to open Safari in private browsing mode using AppleScript.
+func tryAppleScriptSafariPrivate(url string) *exec.Cmd {
+ // AppleScript to open a new private window in Safari
+ script := fmt.Sprintf(`
+ tell application "Safari"
+ activate
+ tell application "System Events"
+ keystroke "n" using {command down, shift down}
+ delay 0.5
+ end tell
+ set URL of document 1 to "%s"
+ end tell
+ `, url)
+
+ cmd := exec.Command("osascript", "-e", script)
+ // Test if this approach works by checking if Safari is available
+ if _, err := exec.LookPath("/Applications/Safari.app/Contents/MacOS/Safari"); err != nil {
+ log.Debug("Safari not found, AppleScript private window not available")
+ return nil
+ }
+ log.Debug("Attempting Safari private window via AppleScript")
+ return cmd
+}
+
+// tryDefaultBrowserWindows detects the default browser on Windows via registry.
+func tryDefaultBrowserWindows(url string) *exec.Cmd {
+ // Query registry for default browser
+ out, err := exec.Command("reg", "query",
+ `HKEY_CURRENT_USER\Software\Microsoft\Windows\Shell\Associations\UrlAssociations\http\UserChoice`,
+ "/v", "ProgId").Output()
+ if err != nil {
+ return nil
+ }
+
+ output := string(out)
+ var browserName string
+
+ // Map ProgId to browser name
+ if strings.Contains(output, "ChromeHTML") {
+ browserName = "chrome"
+ } else if strings.Contains(output, "FirefoxURL") {
+ browserName = "firefox"
+ } else if strings.Contains(output, "MSEdgeHTM") {
+ browserName = "edge"
+ } else if strings.Contains(output, "BraveHTML") {
+ browserName = "brave"
+ }
+
+ return createWindowsIncognitoCmd(browserName, url)
+}
+
+// createWindowsIncognitoCmd creates the appropriate incognito command for Windows browsers.
+func createWindowsIncognitoCmd(browserName, url string) *exec.Cmd {
+ switch browserName {
+ case "chrome":
+ paths := []string{
+ "chrome",
+ `C:\Program Files\Google\Chrome\Application\chrome.exe`,
+ `C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`,
+ }
+ for _, p := range paths {
+ if _, err := exec.LookPath(p); err == nil {
+ return exec.Command(p, "--incognito", url)
+ }
+ }
+ case "firefox":
+ if path, err := exec.LookPath("firefox"); err == nil {
+ return exec.Command(path, "--private-window", url)
+ }
+ case "edge":
+ paths := []string{
+ "msedge",
+ `C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`,
+ `C:\Program Files\Microsoft\Edge\Application\msedge.exe`,
+ }
+ for _, p := range paths {
+ if _, err := exec.LookPath(p); err == nil {
+ return exec.Command(p, "--inprivate", url)
+ }
+ }
+ case "brave":
+ paths := []string{
+ `C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe`,
+ `C:\Program Files (x86)\BraveSoftware\Brave-Browser\Application\brave.exe`,
+ }
+ for _, p := range paths {
+ if _, err := exec.LookPath(p); err == nil {
+ return exec.Command(p, "--incognito", url)
+ }
+ }
+ }
+ return nil
+}
+
+// tryDefaultBrowserLinux detects the default browser on Linux using xdg-settings.
+func tryDefaultBrowserLinux(url string) *exec.Cmd {
+ out, err := exec.Command("xdg-settings", "get", "default-web-browser").Output()
+ if err != nil {
+ return nil
+ }
+
+ desktop := string(out)
+ var browserName string
+
+ // Map .desktop file to browser name
+ if strings.Contains(desktop, "google-chrome") || strings.Contains(desktop, "chrome") {
+ browserName = "chrome"
+ } else if strings.Contains(desktop, "firefox") {
+ browserName = "firefox"
+ } else if strings.Contains(desktop, "chromium") {
+ browserName = "chromium"
+ } else if strings.Contains(desktop, "brave") {
+ browserName = "brave"
+ } else if strings.Contains(desktop, "microsoft-edge") || strings.Contains(desktop, "msedge") {
+ browserName = "edge"
+ }
+
+ return createLinuxIncognitoCmd(browserName, url)
+}
+
+// createLinuxIncognitoCmd creates the appropriate incognito command for Linux browsers.
+func createLinuxIncognitoCmd(browserName, url string) *exec.Cmd {
+ switch browserName {
+ case "chrome":
+ paths := []string{"google-chrome", "google-chrome-stable"}
+ for _, p := range paths {
+ if path, err := exec.LookPath(p); err == nil {
+ return exec.Command(path, "--incognito", url)
+ }
+ }
+ case "firefox":
+ paths := []string{"firefox", "firefox-esr"}
+ for _, p := range paths {
+ if path, err := exec.LookPath(p); err == nil {
+ return exec.Command(path, "--private-window", url)
+ }
+ }
+ case "chromium":
+ paths := []string{"chromium", "chromium-browser"}
+ for _, p := range paths {
+ if path, err := exec.LookPath(p); err == nil {
+ return exec.Command(path, "--incognito", url)
+ }
+ }
+ case "brave":
+ if path, err := exec.LookPath("brave-browser"); err == nil {
+ return exec.Command(path, "--incognito", url)
+ }
+ case "edge":
+ if path, err := exec.LookPath("microsoft-edge"); err == nil {
+ return exec.Command(path, "--inprivate", url)
+ }
+ }
+ return nil
+}
+
+// tryFallbackBrowsersIncognito tries a chain of known browsers as fallback.
+func tryFallbackBrowsersIncognito(url string) *exec.Cmd {
+ switch runtime.GOOS {
+ case "darwin":
+ return tryFallbackBrowsersMacOS(url)
+ case "windows":
+ return tryFallbackBrowsersWindows(url)
+ case "linux":
+ return tryFallbackBrowsersLinuxChain(url)
+ }
+ return nil
+}
+
+// tryFallbackBrowsersMacOS tries known browsers on macOS.
+func tryFallbackBrowsersMacOS(url string) *exec.Cmd {
+ // Try Chrome
+ chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
+ if _, err := exec.LookPath(chromePath); err == nil {
+ return exec.Command(chromePath, "--incognito", url)
+ }
+ // Try Firefox
+ if _, err := exec.LookPath("/Applications/Firefox.app/Contents/MacOS/firefox"); err == nil {
+ return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url)
+ }
+ // Try Brave
+ if _, err := exec.LookPath("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"); err == nil {
+ return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url)
+ }
+ // Try Edge
+ if _, err := exec.LookPath("/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge"); err == nil {
+ return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url)
+ }
+ // Last resort: try Safari with AppleScript
+ if cmd := tryAppleScriptSafariPrivate(url); cmd != nil {
+ log.Info("Using Safari with AppleScript for private browsing (may require accessibility permissions)")
+ return cmd
+ }
+ return nil
+}
+
+// tryFallbackBrowsersWindows tries known browsers on Windows.
+func tryFallbackBrowsersWindows(url string) *exec.Cmd {
+ // Chrome
+ chromePaths := []string{
+ "chrome",
+ `C:\Program Files\Google\Chrome\Application\chrome.exe`,
+ `C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`,
+ }
+ for _, p := range chromePaths {
+ if _, err := exec.LookPath(p); err == nil {
+ return exec.Command(p, "--incognito", url)
+ }
+ }
+ // Firefox
+ if path, err := exec.LookPath("firefox"); err == nil {
+ return exec.Command(path, "--private-window", url)
+ }
+ // Edge (usually available on Windows 10+)
+ edgePaths := []string{
+ "msedge",
+ `C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`,
+ `C:\Program Files\Microsoft\Edge\Application\msedge.exe`,
+ }
+ for _, p := range edgePaths {
+ if _, err := exec.LookPath(p); err == nil {
+ return exec.Command(p, "--inprivate", url)
+ }
+ }
+ return nil
+}
+
+// tryFallbackBrowsersLinuxChain tries known browsers on Linux.
+func tryFallbackBrowsersLinuxChain(url string) *exec.Cmd {
+ type browserConfig struct {
+ name string
+ flag string
+ }
+ browsers := []browserConfig{
+ {"google-chrome", "--incognito"},
+ {"google-chrome-stable", "--incognito"},
+ {"chromium", "--incognito"},
+ {"chromium-browser", "--incognito"},
+ {"firefox", "--private-window"},
+ {"firefox-esr", "--private-window"},
+ {"brave-browser", "--incognito"},
+ {"microsoft-edge", "--inprivate"},
+ }
+ for _, b := range browsers {
+ if path, err := exec.LookPath(b.name); err == nil {
+ return exec.Command(path, b.flag, url)
+ }
+ }
+ return nil
+}
+
// IsAvailable checks if the system has a command available to open a web browser.
// It verifies the presence of necessary commands for the current operating system.
//
// Returns:
// - true if a browser can be opened, false otherwise.
func IsAvailable() bool {
- // First check if open-golang can work
- testErr := open.Run("about:blank")
- if testErr == nil {
- return true
- }
-
// Check platform-specific commands
switch runtime.GOOS {
case "darwin":
diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go
index 7fa1d88e..2a3407be 100644
--- a/internal/cmd/auth_manager.go
+++ b/internal/cmd/auth_manager.go
@@ -6,7 +6,7 @@ import (
// newAuthManager creates a new authentication manager instance with all supported
// authenticators and a file-based token store. It initializes authenticators for
-// Gemini, Codex, Claude, and Qwen providers.
+// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers.
//
// Returns:
// - *sdkAuth.Manager: A configured authentication manager instance
@@ -20,6 +20,9 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewIFlowAuthenticator(),
sdkAuth.NewAntigravityAuthenticator(),
sdkAuth.NewKimiAuthenticator(),
+ sdkAuth.NewKiroAuthenticator(),
+ sdkAuth.NewGitHubCopilotAuthenticator(),
+ sdkAuth.NewKiloAuthenticator(),
)
return manager
}
diff --git a/internal/cmd/github_copilot_login.go b/internal/cmd/github_copilot_login.go
new file mode 100644
index 00000000..056e811f
--- /dev/null
+++ b/internal/cmd/github_copilot_login.go
@@ -0,0 +1,44 @@
+package cmd
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
+ log "github.com/sirupsen/logrus"
+)
+
+// DoGitHubCopilotLogin triggers the OAuth device flow for GitHub Copilot and saves tokens.
+// It initiates the device flow authentication, displays the user code for the user to enter
+// at GitHub's verification URL, and waits for authorization before saving the tokens.
+//
+// Parameters:
+// - cfg: The application configuration containing proxy and auth directory settings
+// - options: Login options including browser behavior settings
+func DoGitHubCopilotLogin(cfg *config.Config, options *LoginOptions) {
+ if options == nil {
+ options = &LoginOptions{}
+ }
+
+ manager := newAuthManager()
+ authOpts := &sdkAuth.LoginOptions{
+ NoBrowser: options.NoBrowser,
+ Metadata: map[string]string{},
+ Prompt: options.Prompt,
+ }
+
+ record, savedPath, err := manager.Login(context.Background(), "github-copilot", cfg, authOpts)
+ if err != nil {
+ log.Errorf("GitHub Copilot authentication failed: %v", err)
+ return
+ }
+
+ if savedPath != "" {
+ fmt.Printf("Authentication saved to %s\n", savedPath)
+ }
+ if record != nil && record.Label != "" {
+ fmt.Printf("Authenticated as %s\n", record.Label)
+ }
+ fmt.Println("GitHub Copilot authentication successful!")
+}
diff --git a/internal/cmd/kilo_login.go b/internal/cmd/kilo_login.go
new file mode 100644
index 00000000..7e9ed3b9
--- /dev/null
+++ b/internal/cmd/kilo_login.go
@@ -0,0 +1,54 @@
+package cmd
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
+)
+
+// DoKiloLogin handles the Kilo device flow using the shared authentication manager.
+// It initiates the device-based authentication process for Kilo AI services and saves
+// the authentication tokens to the configured auth directory.
+//
+// Parameters:
+// - cfg: The application configuration
+// - options: Login options including browser behavior and prompts
+func DoKiloLogin(cfg *config.Config, options *LoginOptions) {
+ if options == nil {
+ options = &LoginOptions{}
+ }
+
+ manager := newAuthManager()
+
+ promptFn := options.Prompt
+ if promptFn == nil {
+ promptFn = func(prompt string) (string, error) {
+ fmt.Print(prompt)
+ var value string
+ fmt.Scanln(&value)
+ return strings.TrimSpace(value), nil
+ }
+ }
+
+ authOpts := &sdkAuth.LoginOptions{
+ NoBrowser: options.NoBrowser,
+ CallbackPort: options.CallbackPort,
+ Metadata: map[string]string{},
+ Prompt: promptFn,
+ }
+
+ _, savedPath, err := manager.Login(context.Background(), "kilo", cfg, authOpts)
+ if err != nil {
+ fmt.Printf("Kilo authentication failed: %v\n", err)
+ return
+ }
+
+ if savedPath != "" {
+ fmt.Printf("Authentication saved to %s\n", savedPath)
+ }
+
+ fmt.Println("Kilo authentication successful!")
+}
diff --git a/internal/cmd/kiro_login.go b/internal/cmd/kiro_login.go
new file mode 100644
index 00000000..74d09686
--- /dev/null
+++ b/internal/cmd/kiro_login.go
@@ -0,0 +1,208 @@
+package cmd
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
+ log "github.com/sirupsen/logrus"
+)
+
+// DoKiroLogin triggers the Kiro authentication flow with Google OAuth.
+// This is the default login method (same as --kiro-google-login).
+//
+// Parameters:
+// - cfg: The application configuration
+// - options: Login options including Prompt field
+func DoKiroLogin(cfg *config.Config, options *LoginOptions) {
+ // Use Google login as default
+ DoKiroGoogleLogin(cfg, options)
+}
+
+// DoKiroGoogleLogin triggers Kiro authentication with Google OAuth.
+// This uses a custom protocol handler (kiro://) to receive the callback.
+//
+// Parameters:
+// - cfg: The application configuration
+// - options: Login options including prompts
+func DoKiroGoogleLogin(cfg *config.Config, options *LoginOptions) {
+ if options == nil {
+ options = &LoginOptions{}
+ }
+
+ // Note: Kiro defaults to incognito mode for multi-account support.
+ // Users can override with --no-incognito if they want to use existing browser sessions.
+
+ manager := newAuthManager()
+
+ // Use KiroAuthenticator with Google login
+ authenticator := sdkAuth.NewKiroAuthenticator()
+ record, err := authenticator.LoginWithGoogle(context.Background(), cfg, &sdkAuth.LoginOptions{
+ NoBrowser: options.NoBrowser,
+ Metadata: map[string]string{},
+ Prompt: options.Prompt,
+ })
+ if err != nil {
+ log.Errorf("Kiro Google authentication failed: %v", err)
+ fmt.Println("\nTroubleshooting:")
+ fmt.Println("1. Make sure the protocol handler is installed")
+ fmt.Println("2. Complete the Google login in the browser")
+ fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)")
+ return
+ }
+
+ // Save the auth record
+ savedPath, err := manager.SaveAuth(record, cfg)
+ if err != nil {
+ log.Errorf("Failed to save auth: %v", err)
+ return
+ }
+
+ if savedPath != "" {
+ fmt.Printf("Authentication saved to %s\n", savedPath)
+ }
+ if record != nil && record.Label != "" {
+ fmt.Printf("Authenticated as %s\n", record.Label)
+ }
+ fmt.Println("Kiro Google authentication successful!")
+}
+
+// DoKiroAWSLogin triggers Kiro authentication with AWS Builder ID.
+// This uses the device code flow for AWS SSO OIDC authentication.
+//
+// Parameters:
+// - cfg: The application configuration
+// - options: Login options including prompts
+func DoKiroAWSLogin(cfg *config.Config, options *LoginOptions) {
+ if options == nil {
+ options = &LoginOptions{}
+ }
+
+ // Note: Kiro defaults to incognito mode for multi-account support.
+ // Users can override with --no-incognito if they want to use existing browser sessions.
+
+ manager := newAuthManager()
+
+ // Use KiroAuthenticator with AWS Builder ID login (device code flow)
+ authenticator := sdkAuth.NewKiroAuthenticator()
+ record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{
+ NoBrowser: options.NoBrowser,
+ Metadata: map[string]string{},
+ Prompt: options.Prompt,
+ })
+ if err != nil {
+ log.Errorf("Kiro AWS authentication failed: %v", err)
+ fmt.Println("\nTroubleshooting:")
+ fmt.Println("1. Make sure you have an AWS Builder ID")
+ fmt.Println("2. Complete the authorization in the browser")
+ fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)")
+ return
+ }
+
+ // Save the auth record
+ savedPath, err := manager.SaveAuth(record, cfg)
+ if err != nil {
+ log.Errorf("Failed to save auth: %v", err)
+ return
+ }
+
+ if savedPath != "" {
+ fmt.Printf("Authentication saved to %s\n", savedPath)
+ }
+ if record != nil && record.Label != "" {
+ fmt.Printf("Authenticated as %s\n", record.Label)
+ }
+ fmt.Println("Kiro AWS authentication successful!")
+}
+
+// DoKiroAWSAuthCodeLogin triggers Kiro authentication with AWS Builder ID using authorization code flow.
+// This provides a better UX than device code flow as it uses automatic browser callback.
+//
+// Parameters:
+// - cfg: The application configuration
+// - options: Login options including prompts
+func DoKiroAWSAuthCodeLogin(cfg *config.Config, options *LoginOptions) {
+ if options == nil {
+ options = &LoginOptions{}
+ }
+
+ // Note: Kiro defaults to incognito mode for multi-account support.
+ // Users can override with --no-incognito if they want to use existing browser sessions.
+
+ manager := newAuthManager()
+
+ // Use KiroAuthenticator with AWS Builder ID login (authorization code flow)
+ authenticator := sdkAuth.NewKiroAuthenticator()
+ record, err := authenticator.LoginWithAuthCode(context.Background(), cfg, &sdkAuth.LoginOptions{
+ NoBrowser: options.NoBrowser,
+ Metadata: map[string]string{},
+ Prompt: options.Prompt,
+ })
+ if err != nil {
+ log.Errorf("Kiro AWS authentication (auth code) failed: %v", err)
+ fmt.Println("\nTroubleshooting:")
+ fmt.Println("1. Make sure you have an AWS Builder ID")
+ fmt.Println("2. Complete the authorization in the browser")
+ fmt.Println("3. If callback fails, try: --kiro-aws-login (device code flow)")
+ return
+ }
+
+ // Save the auth record
+ savedPath, err := manager.SaveAuth(record, cfg)
+ if err != nil {
+ log.Errorf("Failed to save auth: %v", err)
+ return
+ }
+
+ if savedPath != "" {
+ fmt.Printf("Authentication saved to %s\n", savedPath)
+ }
+ if record != nil && record.Label != "" {
+ fmt.Printf("Authenticated as %s\n", record.Label)
+ }
+ fmt.Println("Kiro AWS authentication successful!")
+}
+
+// DoKiroImport imports Kiro token from Kiro IDE's token file.
+// This is useful for users who have already logged in via Kiro IDE
+// and want to use the same credentials in CLI Proxy API.
+//
+// Parameters:
+// - cfg: The application configuration
+// - options: Login options (currently unused for import)
+func DoKiroImport(cfg *config.Config, options *LoginOptions) {
+ if options == nil {
+ options = &LoginOptions{}
+ }
+
+ manager := newAuthManager()
+
+ // Use ImportFromKiroIDE instead of Login
+ authenticator := sdkAuth.NewKiroAuthenticator()
+ record, err := authenticator.ImportFromKiroIDE(context.Background(), cfg)
+ if err != nil {
+ log.Errorf("Kiro token import failed: %v", err)
+ fmt.Println("\nMake sure you have logged in to Kiro IDE first:")
+ fmt.Println("1. Open Kiro IDE")
+ fmt.Println("2. Click 'Sign in with Google' (or GitHub)")
+ fmt.Println("3. Complete the login process")
+ fmt.Println("4. Run this command again")
+ return
+ }
+
+ // Save the imported auth record
+ savedPath, err := manager.SaveAuth(record, cfg)
+ if err != nil {
+ log.Errorf("Failed to save auth: %v", err)
+ return
+ }
+
+ if savedPath != "" {
+ fmt.Printf("Authentication saved to %s\n", savedPath)
+ }
+ if record != nil && record.Label != "" {
+ fmt.Printf("Imported as %s\n", record.Label)
+ }
+ fmt.Println("Kiro token import successful!")
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index 5b18f3df..e28483d5 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -84,6 +84,13 @@ type Config struct {
// GeminiKey defines Gemini API key configurations with optional routing overrides.
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
+ // KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations.
+ KiroKey []KiroKey `yaml:"kiro" json:"kiro"`
+
+ // KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers.
+ // Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q).
+ KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"`
+
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
@@ -105,11 +112,12 @@ type Config struct {
AmpCode AmpCode `yaml:"ampcode" json:"ampcode"`
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
+ // Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
// These aliases affect both model listing and model routing for supported channels:
- // gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
+ // gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot.
//
// NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
@@ -118,6 +126,11 @@ type Config struct {
// Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"`
+ // IncognitoBrowser enables opening OAuth URLs in incognito/private browsing mode.
+ // This is useful when you want to login with a different account without logging out
+ // from your current session. Default: false.
+ IncognitoBrowser bool `yaml:"incognito-browser" json:"incognito-browser"`
+
legacyMigrationPending bool `yaml:"-" json:"-"`
}
@@ -443,6 +456,35 @@ type GeminiModel struct {
func (m GeminiModel) GetName() string { return m.Name }
func (m GeminiModel) GetAlias() string { return m.Alias }
+// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication.
+type KiroKey struct {
+ // TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json)
+ TokenFile string `yaml:"token-file,omitempty" json:"token-file,omitempty"`
+
+ // AccessToken is the OAuth access token for direct configuration.
+ AccessToken string `yaml:"access-token,omitempty" json:"access-token,omitempty"`
+
+ // RefreshToken is the OAuth refresh token for token renewal.
+ RefreshToken string `yaml:"refresh-token,omitempty" json:"refresh-token,omitempty"`
+
+ // ProfileArn is the AWS CodeWhisperer profile ARN.
+ ProfileArn string `yaml:"profile-arn,omitempty" json:"profile-arn,omitempty"`
+
+ // Region is the AWS region (default: us-east-1).
+ Region string `yaml:"region,omitempty" json:"region,omitempty"`
+
+ // ProxyURL optionally overrides the global proxy for this configuration.
+ ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
+
+ // AgentTaskType sets the Kiro API task type. Known values: "vibe", "dev", "chat".
+ // Leave empty to let API use defaults. Different values may inject different system prompts.
+ AgentTaskType string `yaml:"agent-task-type,omitempty" json:"agent-task-type,omitempty"`
+
+ // PreferredEndpoint sets the preferred Kiro API endpoint/quota.
+ // Values: "codewhisperer" (default, IDE quota) or "amazonq" (CLI quota).
+ PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"`
+}
+
// OpenAICompatibility represents the configuration for OpenAI API compatibility
// with external providers, allowing model aliases to be routed through OpenAI API format.
type OpenAICompatibility struct {
@@ -549,6 +591,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.Pprof.Addr = DefaultPprofAddr
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
+ cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
if err = yaml.Unmarshal(data, &cfg); err != nil {
if optional {
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
@@ -617,6 +660,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Sanitize Claude key headers
cfg.SanitizeClaudeKeys()
+ // Sanitize Kiro keys: trim whitespace from credential fields
+ cfg.SanitizeKiroKeys()
+
// Sanitize OpenAI compatibility providers: drop entries without base-url
cfg.SanitizeOpenAICompatibility()
@@ -706,14 +752,44 @@ func payloadRawString(value any) ([]byte, bool) {
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
// It trims whitespace, normalizes channel keys to lower-case, drops empty entries,
// allows multiple aliases per upstream name, and ensures aliases are unique within each channel.
+// It also injects default aliases for channels that have built-in defaults (e.g., kiro)
+// when no user-configured aliases exist for those channels.
func (cfg *Config) SanitizeOAuthModelAlias() {
- if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
+ if cfg == nil {
+ return
+ }
+
+ // Inject default Kiro aliases if no user-configured kiro aliases exist
+ if cfg.OAuthModelAlias == nil {
+ cfg.OAuthModelAlias = make(map[string][]OAuthModelAlias)
+ }
+ if _, hasKiro := cfg.OAuthModelAlias["kiro"]; !hasKiro {
+ // Check case-insensitive too
+ found := false
+ for k := range cfg.OAuthModelAlias {
+ if strings.EqualFold(strings.TrimSpace(k), "kiro") {
+ found = true
+ break
+ }
+ }
+ if !found {
+ cfg.OAuthModelAlias["kiro"] = defaultKiroAliases()
+ }
+ }
+
+ if len(cfg.OAuthModelAlias) == 0 {
return
}
out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
for rawChannel, aliases := range cfg.OAuthModelAlias {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
- if channel == "" || len(aliases) == 0 {
+ if channel == "" {
+ continue
+ }
+ // Preserve channels that were explicitly set to empty/nil – they act
+ // as "disabled" markers so default injection won't re-add them (#222).
+ if len(aliases) == 0 {
+ out[channel] = nil
continue
}
seenAlias := make(map[string]struct{}, len(aliases))
@@ -798,6 +874,23 @@ func (cfg *Config) SanitizeClaudeKeys() {
}
}
+// SanitizeKiroKeys trims whitespace from Kiro credential fields.
+func (cfg *Config) SanitizeKiroKeys() {
+ if cfg == nil || len(cfg.KiroKey) == 0 {
+ return
+ }
+ for i := range cfg.KiroKey {
+ entry := &cfg.KiroKey[i]
+ entry.TokenFile = strings.TrimSpace(entry.TokenFile)
+ entry.AccessToken = strings.TrimSpace(entry.AccessToken)
+ entry.RefreshToken = strings.TrimSpace(entry.RefreshToken)
+ entry.ProfileArn = strings.TrimSpace(entry.ProfileArn)
+ entry.Region = strings.TrimSpace(entry.Region)
+ entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
+ entry.PreferredEndpoint = strings.TrimSpace(entry.PreferredEndpoint)
+ }
+}
+
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
func (cfg *Config) SanitizeGeminiKeys() {
if cfg == nil {
diff --git a/internal/config/oauth_model_alias_migration.go b/internal/config/oauth_model_alias_migration.go
index f52df27a..639cbccd 100644
--- a/internal/config/oauth_model_alias_migration.go
+++ b/internal/config/oauth_model_alias_migration.go
@@ -20,6 +20,28 @@ var antigravityModelConversionTable = map[string]string{
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
}
+// defaultKiroAliases returns the default oauth-model-alias configuration
+// for the kiro channel. Maps kiro-prefixed model names to standard Claude model
+// names so that clients like Claude Code can use standard names directly.
+func defaultKiroAliases() []OAuthModelAlias {
+ return []OAuthModelAlias{
+ // Sonnet 4.5
+ {Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5-20250929", Fork: true},
+ {Name: "kiro-claude-sonnet-4-5", Alias: "claude-sonnet-4-5", Fork: true},
+ // Sonnet 4
+ {Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4-20250514", Fork: true},
+ {Name: "kiro-claude-sonnet-4", Alias: "claude-sonnet-4", Fork: true},
+ // Opus 4.6
+ {Name: "kiro-claude-opus-4-6", Alias: "claude-opus-4-6", Fork: true},
+ // Opus 4.5
+ {Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5-20251101", Fork: true},
+ {Name: "kiro-claude-opus-4-5", Alias: "claude-opus-4-5", Fork: true},
+ // Haiku 4.5
+ {Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5-20251001", Fork: true},
+ {Name: "kiro-claude-haiku-4-5", Alias: "claude-haiku-4-5", Fork: true},
+ }
+}
+
// defaultAntigravityAliases returns the default oauth-model-alias configuration
// for the antigravity channel when neither field exists.
func defaultAntigravityAliases() []OAuthModelAlias {
diff --git a/internal/config/oauth_model_alias_test.go b/internal/config/oauth_model_alias_test.go
index a5886474..5cf05502 100644
--- a/internal/config/oauth_model_alias_test.go
+++ b/internal/config/oauth_model_alias_test.go
@@ -54,3 +54,132 @@ func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T)
}
}
}
+
+func TestSanitizeOAuthModelAlias_InjectsDefaultKiroAliases(t *testing.T) {
+ // When no kiro aliases are configured, defaults should be injected
+ cfg := &Config{
+ OAuthModelAlias: map[string][]OAuthModelAlias{
+ "codex": {
+ {Name: "gpt-5", Alias: "g5"},
+ },
+ },
+ }
+
+ cfg.SanitizeOAuthModelAlias()
+
+ kiroAliases := cfg.OAuthModelAlias["kiro"]
+ if len(kiroAliases) == 0 {
+ t.Fatal("expected default kiro aliases to be injected")
+ }
+
+ // Check that standard Claude model names are present
+ aliasSet := make(map[string]bool)
+ for _, a := range kiroAliases {
+ aliasSet[a.Alias] = true
+ }
+ expectedAliases := []string{
+ "claude-sonnet-4-5-20250929",
+ "claude-sonnet-4-5",
+ "claude-sonnet-4-20250514",
+ "claude-sonnet-4",
+ "claude-opus-4-6",
+ "claude-opus-4-5-20251101",
+ "claude-opus-4-5",
+ "claude-haiku-4-5-20251001",
+ "claude-haiku-4-5",
+ }
+ for _, expected := range expectedAliases {
+ if !aliasSet[expected] {
+ t.Fatalf("expected default kiro alias %q to be present", expected)
+ }
+ }
+
+ // All should have fork=true
+ for _, a := range kiroAliases {
+ if !a.Fork {
+ t.Fatalf("expected all default kiro aliases to have fork=true, got fork=false for %q", a.Alias)
+ }
+ }
+
+ // Codex aliases should still be preserved
+ if len(cfg.OAuthModelAlias["codex"]) != 1 {
+ t.Fatal("expected codex aliases to be preserved")
+ }
+}
+
+func TestSanitizeOAuthModelAlias_DoesNotOverrideUserKiroAliases(t *testing.T) {
+ // When user has configured kiro aliases, defaults should NOT be injected
+ cfg := &Config{
+ OAuthModelAlias: map[string][]OAuthModelAlias{
+ "kiro": {
+ {Name: "kiro-claude-sonnet-4", Alias: "my-custom-sonnet", Fork: true},
+ },
+ },
+ }
+
+ cfg.SanitizeOAuthModelAlias()
+
+ kiroAliases := cfg.OAuthModelAlias["kiro"]
+ if len(kiroAliases) != 1 {
+ t.Fatalf("expected 1 user-configured kiro alias, got %d", len(kiroAliases))
+ }
+ if kiroAliases[0].Alias != "my-custom-sonnet" {
+ t.Fatalf("expected user alias to be preserved, got %q", kiroAliases[0].Alias)
+ }
+}
+
+func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletion(t *testing.T) {
+ // When user explicitly deletes kiro aliases (key exists with nil value),
+ // defaults should NOT be re-injected on subsequent sanitize calls (#222).
+ cfg := &Config{
+ OAuthModelAlias: map[string][]OAuthModelAlias{
+ "kiro": nil, // explicitly deleted
+ "codex": {{Name: "gpt-5", Alias: "g5"}},
+ },
+ }
+
+ cfg.SanitizeOAuthModelAlias()
+
+ kiroAliases := cfg.OAuthModelAlias["kiro"]
+ if len(kiroAliases) != 0 {
+ t.Fatalf("expected kiro aliases to remain empty after explicit deletion, got %d aliases", len(kiroAliases))
+ }
+ // The key itself must still be present to prevent re-injection on next reload
+ if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
+ t.Fatal("expected kiro key to be preserved as nil marker after sanitization")
+ }
+ // Other channels should be unaffected
+ if len(cfg.OAuthModelAlias["codex"]) != 1 {
+ t.Fatal("expected codex aliases to be preserved")
+ }
+}
+
+func TestSanitizeOAuthModelAlias_DoesNotReinjectAfterExplicitDeletionEmpty(t *testing.T) {
+ // Same as above but with empty slice instead of nil (PUT with empty body).
+ cfg := &Config{
+ OAuthModelAlias: map[string][]OAuthModelAlias{
+ "kiro": {}, // explicitly set to empty
+ },
+ }
+
+ cfg.SanitizeOAuthModelAlias()
+
+ if len(cfg.OAuthModelAlias["kiro"]) != 0 {
+ t.Fatalf("expected kiro aliases to remain empty, got %d aliases", len(cfg.OAuthModelAlias["kiro"]))
+ }
+ if _, exists := cfg.OAuthModelAlias["kiro"]; !exists {
+ t.Fatal("expected kiro key to be preserved")
+ }
+}
+
+func TestSanitizeOAuthModelAlias_InjectsDefaultKiroWhenEmpty(t *testing.T) {
+ // When OAuthModelAlias is nil, kiro defaults should still be injected
+ cfg := &Config{}
+
+ cfg.SanitizeOAuthModelAlias()
+
+ kiroAliases := cfg.OAuthModelAlias["kiro"]
+ if len(kiroAliases) == 0 {
+ t.Fatal("expected default kiro aliases to be injected when OAuthModelAlias is nil")
+ }
+}
diff --git a/internal/constant/constant.go b/internal/constant/constant.go
index 58b388a1..9b7d31aa 100644
--- a/internal/constant/constant.go
+++ b/internal/constant/constant.go
@@ -24,4 +24,10 @@ const (
// Antigravity represents the Antigravity response format identifier.
Antigravity = "antigravity"
+
+ // Kiro represents the AWS CodeWhisperer (Kiro) provider identifier.
+ Kiro = "kiro"
+
+ // Kilo represents the Kilo AI provider identifier.
+ Kilo = "kilo"
)
diff --git a/internal/logging/global_logger.go b/internal/logging/global_logger.go
index 372222a5..484ecba7 100644
--- a/internal/logging/global_logger.go
+++ b/internal/logging/global_logger.go
@@ -85,6 +85,7 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
func SetupBaseLogger() {
setupOnce.Do(func() {
log.SetOutput(os.Stdout)
+ log.SetLevel(log.InfoLevel)
log.SetReportCaller(true)
log.SetFormatter(&LogFormatter{})
diff --git a/internal/registry/kilo_models.go b/internal/registry/kilo_models.go
new file mode 100644
index 00000000..ac9939db
--- /dev/null
+++ b/internal/registry/kilo_models.go
@@ -0,0 +1,21 @@
+// Package registry provides model definitions for various AI service providers.
+package registry
+
+// GetKiloModels returns the Kilo model definitions
+func GetKiloModels() []*ModelInfo {
+ return []*ModelInfo{
+ // --- Base Models ---
+ {
+ ID: "kilo/auto",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "kilo",
+ Type: "kilo",
+ DisplayName: "Kilo Auto",
+ Description: "Automatic model selection by Kilo",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ }
+}
diff --git a/internal/registry/kiro_model_converter.go b/internal/registry/kiro_model_converter.go
new file mode 100644
index 00000000..fe50a8f3
--- /dev/null
+++ b/internal/registry/kiro_model_converter.go
@@ -0,0 +1,303 @@
+// Package registry provides Kiro model conversion utilities.
+// This file handles converting dynamic Kiro API model lists to the internal ModelInfo format,
+// and merging with static metadata for thinking support and other capabilities.
+package registry
+
+import (
+ "strings"
+ "time"
+)
+
+// KiroAPIModel represents a model from Kiro API response.
+// This is a local copy to avoid import cycles with the kiro package.
+// The structure mirrors kiro.KiroModel for easy data conversion.
+type KiroAPIModel struct {
+ // ModelID is the unique identifier for the model (e.g., "claude-sonnet-4.5")
+ ModelID string
+ // ModelName is the human-readable name
+ ModelName string
+ // Description is the model description
+ Description string
+ // RateMultiplier is the credit multiplier for this model
+ RateMultiplier float64
+ // RateUnit is the unit for rate calculation (e.g., "credit")
+ RateUnit string
+ // MaxInputTokens is the maximum input token limit
+ MaxInputTokens int
+}
+
+// DefaultKiroThinkingSupport defines the default thinking configuration for Kiro models.
+// All Kiro models support thinking with the following budget range.
+var DefaultKiroThinkingSupport = &ThinkingSupport{
+ Min: 1024, // Minimum thinking budget tokens
+ Max: 32000, // Maximum thinking budget tokens
+ ZeroAllowed: true, // Allow disabling thinking with 0
+ DynamicAllowed: true, // Allow dynamic thinking budget (-1)
+}
+
+// DefaultKiroContextLength is the default context window size for Kiro models.
+const DefaultKiroContextLength = 200000
+
+// DefaultKiroMaxCompletionTokens is the default max completion tokens for Kiro models.
+const DefaultKiroMaxCompletionTokens = 64000
+
+// ConvertKiroAPIModels converts Kiro API models to internal ModelInfo format.
+// It performs the following transformations:
+// - Normalizes model ID (e.g., claude-sonnet-4.5 → kiro-claude-sonnet-4-5)
+// - Adds default thinking support metadata
+// - Sets default context length and max completion tokens if not provided
+//
+// Parameters:
+// - kiroModels: List of models from Kiro API response
+//
+// Returns:
+// - []*ModelInfo: Converted model information list
+func ConvertKiroAPIModels(kiroModels []*KiroAPIModel) []*ModelInfo {
+ if len(kiroModels) == 0 {
+ return nil
+ }
+
+ now := time.Now().Unix()
+ result := make([]*ModelInfo, 0, len(kiroModels))
+
+ for _, km := range kiroModels {
+ // Skip nil models
+ if km == nil {
+ continue
+ }
+
+ // Skip models without valid ID
+ if km.ModelID == "" {
+ continue
+ }
+
+ // Normalize the model ID to kiro-* format
+ normalizedID := normalizeKiroModelID(km.ModelID)
+
+ // Create ModelInfo with converted data
+ info := &ModelInfo{
+ ID: normalizedID,
+ Object: "model",
+ Created: now,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: generateKiroDisplayName(km.ModelName, normalizedID),
+ Description: km.Description,
+ // Use MaxInputTokens from API if available, otherwise use default
+ ContextLength: getContextLength(km.MaxInputTokens),
+ MaxCompletionTokens: DefaultKiroMaxCompletionTokens,
+ // All Kiro models support thinking
+ Thinking: cloneThinkingSupport(DefaultKiroThinkingSupport),
+ }
+
+ result = append(result, info)
+ }
+
+ return result
+}
+
+// GenerateAgenticVariants creates -agentic variants for each model.
+// Agentic variants are optimized for coding agents with chunked writes.
+//
+// Parameters:
+// - models: Base models to generate variants for
+//
+// Returns:
+// - []*ModelInfo: Combined list of base models and their agentic variants
+func GenerateAgenticVariants(models []*ModelInfo) []*ModelInfo {
+ if len(models) == 0 {
+ return nil
+ }
+
+ // Pre-allocate result with capacity for both base models and variants
+ result := make([]*ModelInfo, 0, len(models)*2)
+
+ for _, model := range models {
+ if model == nil {
+ continue
+ }
+
+ // Add the base model first
+ result = append(result, model)
+
+ // Skip if model already has -agentic suffix
+ if strings.HasSuffix(model.ID, "-agentic") {
+ continue
+ }
+
+ // Skip special models that shouldn't have agentic variants
+ if model.ID == "kiro-auto" {
+ continue
+ }
+
+ // Create agentic variant
+ agenticModel := &ModelInfo{
+ ID: model.ID + "-agentic",
+ Object: model.Object,
+ Created: model.Created,
+ OwnedBy: model.OwnedBy,
+ Type: model.Type,
+ DisplayName: model.DisplayName + " (Agentic)",
+ Description: generateAgenticDescription(model.Description),
+ ContextLength: model.ContextLength,
+ MaxCompletionTokens: model.MaxCompletionTokens,
+ Thinking: cloneThinkingSupport(model.Thinking),
+ }
+
+ result = append(result, agenticModel)
+ }
+
+ return result
+}
+
+// MergeWithStaticMetadata merges dynamic models with static metadata.
+// Static metadata takes priority for any overlapping fields.
+// This allows manual overrides for specific models while keeping dynamic discovery.
+//
+// Parameters:
+// - dynamicModels: Models from Kiro API (converted to ModelInfo)
+// - staticModels: Predefined model metadata (from GetKiroModels())
+//
+// Returns:
+// - []*ModelInfo: Merged model list with static metadata taking priority
+func MergeWithStaticMetadata(dynamicModels, staticModels []*ModelInfo) []*ModelInfo {
+ if len(dynamicModels) == 0 && len(staticModels) == 0 {
+ return nil
+ }
+
+ // Build a map of static models for quick lookup
+ staticMap := make(map[string]*ModelInfo, len(staticModels))
+ for _, sm := range staticModels {
+ if sm != nil && sm.ID != "" {
+ staticMap[sm.ID] = sm
+ }
+ }
+
+ // Build result, preferring static metadata where available
+ seenIDs := make(map[string]struct{})
+ result := make([]*ModelInfo, 0, len(dynamicModels)+len(staticModels))
+
+ // First, process dynamic models and merge with static if available
+ for _, dm := range dynamicModels {
+ if dm == nil || dm.ID == "" {
+ continue
+ }
+
+ // Skip duplicates
+ if _, seen := seenIDs[dm.ID]; seen {
+ continue
+ }
+ seenIDs[dm.ID] = struct{}{}
+
+ // Check if static metadata exists for this model
+ if sm, exists := staticMap[dm.ID]; exists {
+ // Static metadata takes priority - use static model
+ result = append(result, sm)
+ } else {
+ // No static metadata - use dynamic model
+ result = append(result, dm)
+ }
+ }
+
+ // Add any static models not in dynamic list
+ for _, sm := range staticModels {
+ if sm == nil || sm.ID == "" {
+ continue
+ }
+ if _, seen := seenIDs[sm.ID]; seen {
+ continue
+ }
+ seenIDs[sm.ID] = struct{}{}
+ result = append(result, sm)
+ }
+
+ return result
+}
+
+// normalizeKiroModelID converts Kiro API model IDs to internal format.
+// Transformation rules:
+// - Adds "kiro-" prefix if not present
+// - Replaces dots with hyphens (e.g., 4.5 → 4-5)
+// - Handles special cases like "auto" → "kiro-auto"
+//
+// Examples:
+// - "claude-sonnet-4.5" → "kiro-claude-sonnet-4-5"
+// - "claude-opus-4.5" → "kiro-claude-opus-4-5"
+// - "auto" → "kiro-auto"
+// - "kiro-claude-sonnet-4-5" → "kiro-claude-sonnet-4-5" (unchanged)
+func normalizeKiroModelID(modelID string) string {
+ if modelID == "" {
+ return ""
+ }
+
+ // Trim whitespace
+ modelID = strings.TrimSpace(modelID)
+
+ // Replace dots with hyphens (e.g., 4.5 → 4-5)
+ normalized := strings.ReplaceAll(modelID, ".", "-")
+
+ // Add kiro- prefix if not present
+ if !strings.HasPrefix(normalized, "kiro-") {
+ normalized = "kiro-" + normalized
+ }
+
+ return normalized
+}
+
+// generateKiroDisplayName creates a human-readable display name.
+// Uses the API-provided model name if available, otherwise generates from ID.
+func generateKiroDisplayName(modelName, normalizedID string) string {
+ if modelName != "" {
+ return "Kiro " + modelName
+ }
+
+ // Generate from normalized ID by removing kiro- prefix and formatting
+ displayID := strings.TrimPrefix(normalizedID, "kiro-")
+ // Capitalize first letter of each word
+ words := strings.Split(displayID, "-")
+ for i, word := range words {
+ if len(word) > 0 {
+ words[i] = strings.ToUpper(word[:1]) + word[1:]
+ }
+ }
+ return "Kiro " + strings.Join(words, " ")
+}
+
+// generateAgenticDescription creates description for agentic variants.
+func generateAgenticDescription(baseDescription string) string {
+ if baseDescription == "" {
+ return "Optimized for coding agents with chunked writes"
+ }
+ return baseDescription + " (Agentic mode: chunked writes)"
+}
+
+// getContextLength returns the context length, using default if not provided.
+func getContextLength(maxInputTokens int) int {
+ if maxInputTokens > 0 {
+ return maxInputTokens
+ }
+ return DefaultKiroContextLength
+}
+
+// cloneThinkingSupport creates a deep copy of ThinkingSupport.
+// Returns nil if input is nil.
+func cloneThinkingSupport(ts *ThinkingSupport) *ThinkingSupport {
+ if ts == nil {
+ return nil
+ }
+
+ clone := &ThinkingSupport{
+ Min: ts.Min,
+ Max: ts.Max,
+ ZeroAllowed: ts.ZeroAllowed,
+ DynamicAllowed: ts.DynamicAllowed,
+ }
+
+ // Deep copy Levels slice if present
+ if len(ts.Levels) > 0 {
+ clone.Levels = make([]string, len(ts.Levels))
+ copy(clone.Levels, ts.Levels)
+ }
+
+ return clone
+}
diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go
index c1796979..cf6b3f09 100644
--- a/internal/registry/model_definitions.go
+++ b/internal/registry/model_definitions.go
@@ -20,6 +20,11 @@ import (
// - qwen
// - iflow
// - kimi
+// - kiro
+// - kilo
+// - github-copilot
+// - kiro
+// - amazonq
// - antigravity (returns static overrides only)
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
key := strings.ToLower(strings.TrimSpace(channel))
@@ -42,6 +47,14 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetIFlowModels()
case "kimi":
return GetKimiModels()
+ case "github-copilot":
+ return GetGitHubCopilotModels()
+ case "kiro":
+ return GetKiroModels()
+ case "kilo":
+ return GetKiloModels()
+ case "amazonq":
+ return GetAmazonQModels()
case "antigravity":
cfg := GetAntigravityModelConfig()
if len(cfg) == 0 {
@@ -87,6 +100,10 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
GetQwenModels(),
GetIFlowModels(),
GetKimiModels(),
+ GetGitHubCopilotModels(),
+ GetKiroModels(),
+ GetKiloModels(),
+ GetAmazonQModels(),
}
for _, models := range allModels {
for _, m := range models {
@@ -107,3 +124,599 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
return nil
}
+
+// GetGitHubCopilotModels returns the available models for GitHub Copilot.
+// These models are available through the GitHub Copilot API at api.githubcopilot.com.
+func GetGitHubCopilotModels() []*ModelInfo {
+ now := int64(1732752000) // 2024-11-27
+ return []*ModelInfo{
+ {
+ ID: "gpt-4.1",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-4.1",
+ Description: "OpenAI GPT-4.1 via GitHub Copilot",
+ ContextLength: 128000,
+ MaxCompletionTokens: 16384,
+ },
+ {
+ ID: "gpt-5",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5",
+ Description: "OpenAI GPT-5 via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 32768,
+ SupportedEndpoints: []string{"/chat/completions", "/responses"},
+ Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
+ },
+ {
+ ID: "gpt-5-mini",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5 Mini",
+ Description: "OpenAI GPT-5 Mini via GitHub Copilot",
+ ContextLength: 128000,
+ MaxCompletionTokens: 16384,
+ SupportedEndpoints: []string{"/chat/completions", "/responses"},
+ Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
+ },
+ {
+ ID: "gpt-5-codex",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5 Codex",
+ Description: "OpenAI GPT-5 Codex via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 32768,
+ SupportedEndpoints: []string{"/responses"},
+ Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}},
+ },
+ {
+ ID: "gpt-5.1",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5.1",
+ Description: "OpenAI GPT-5.1 via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 32768,
+ SupportedEndpoints: []string{"/chat/completions", "/responses"},
+ Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
+ },
+ {
+ ID: "gpt-5.1-codex",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5.1 Codex",
+ Description: "OpenAI GPT-5.1 Codex via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 32768,
+ SupportedEndpoints: []string{"/responses"},
+ Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
+ },
+ {
+ ID: "gpt-5.1-codex-mini",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5.1 Codex Mini",
+ Description: "OpenAI GPT-5.1 Codex Mini via GitHub Copilot",
+ ContextLength: 128000,
+ MaxCompletionTokens: 16384,
+ SupportedEndpoints: []string{"/responses"},
+ Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}},
+ },
+ {
+ ID: "gpt-5.1-codex-max",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5.1 Codex Max",
+ Description: "OpenAI GPT-5.1 Codex Max via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 32768,
+ SupportedEndpoints: []string{"/responses"},
+ Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
+ },
+ {
+ ID: "gpt-5.2",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5.2",
+ Description: "OpenAI GPT-5.2 via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 32768,
+ SupportedEndpoints: []string{"/chat/completions", "/responses"},
+ Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
+ },
+ {
+ ID: "gpt-5.2-codex",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5.2 Codex",
+ Description: "OpenAI GPT-5.2 Codex via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 32768,
+ SupportedEndpoints: []string{"/responses"},
+ Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
+ },
+ {
+ ID: "gpt-5.3-codex",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5.3 Codex",
+ Description: "OpenAI GPT-5.3 Codex via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 32768,
+ SupportedEndpoints: []string{"/responses"},
+ Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
+ },
+ {
+ ID: "claude-haiku-4.5",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Claude Haiku 4.5",
+ Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ SupportedEndpoints: []string{"/chat/completions"},
+ },
+ {
+ ID: "claude-opus-4.1",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Claude Opus 4.1",
+ Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 32000,
+ SupportedEndpoints: []string{"/chat/completions"},
+ },
+ {
+ ID: "claude-opus-4.5",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Claude Opus 4.5",
+ Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ SupportedEndpoints: []string{"/chat/completions"},
+ },
+ {
+ ID: "claude-opus-4.6",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Claude Opus 4.6",
+ Description: "Anthropic Claude Opus 4.6 via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ SupportedEndpoints: []string{"/chat/completions"},
+ },
+ {
+ ID: "claude-sonnet-4",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Claude Sonnet 4",
+ Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ SupportedEndpoints: []string{"/chat/completions"},
+ },
+ {
+ ID: "claude-sonnet-4.5",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Claude Sonnet 4.5",
+ Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ SupportedEndpoints: []string{"/chat/completions"},
+ },
+ {
+ ID: "claude-sonnet-4.6",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Claude Sonnet 4.6",
+ Description: "Anthropic Claude Sonnet 4.6 via GitHub Copilot",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ SupportedEndpoints: []string{"/chat/completions"},
+ },
+ {
+ ID: "gemini-2.5-pro",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Gemini 2.5 Pro",
+ Description: "Google Gemini 2.5 Pro via GitHub Copilot",
+ ContextLength: 1048576,
+ MaxCompletionTokens: 65536,
+ },
+ {
+ ID: "gemini-3-pro-preview",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Gemini 3 Pro (Preview)",
+ Description: "Google Gemini 3 Pro Preview via GitHub Copilot",
+ ContextLength: 1048576,
+ MaxCompletionTokens: 65536,
+ },
+ {
+ ID: "gemini-3-flash-preview",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Gemini 3 Flash (Preview)",
+ Description: "Google Gemini 3 Flash Preview via GitHub Copilot",
+ ContextLength: 1048576,
+ MaxCompletionTokens: 65536,
+ },
+ {
+ ID: "grok-code-fast-1",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Grok Code Fast 1",
+ Description: "xAI Grok Code Fast 1 via GitHub Copilot",
+ ContextLength: 128000,
+ MaxCompletionTokens: 16384,
+ },
+ {
+ ID: "oswe-vscode-prime",
+ Object: "model",
+ Created: now,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "Raptor mini (Preview)",
+ Description: "Raptor mini via GitHub Copilot",
+ ContextLength: 128000,
+ MaxCompletionTokens: 16384,
+ SupportedEndpoints: []string{"/chat/completions", "/responses"},
+ },
+ }
+}
+
+// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
+func GetKiroModels() []*ModelInfo {
+ return []*ModelInfo{
+ // --- Base Models ---
+ {
+ ID: "kiro-auto",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Auto",
+ Description: "Automatic model selection by Kiro",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-claude-opus-4-6",
+ Object: "model",
+ Created: 1736899200, // 2025-01-15
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Opus 4.6",
+ Description: "Claude Opus 4.6 via Kiro (2.2x credit)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-claude-sonnet-4-6",
+ Object: "model",
+ Created: 1739836800, // 2025-02-18
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Sonnet 4.6",
+ Description: "Claude Sonnet 4.6 via Kiro (1.3x credit)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-claude-opus-4-5",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Opus 4.5",
+ Description: "Claude Opus 4.5 via Kiro (2.2x credit)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-claude-sonnet-4-5",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Sonnet 4.5",
+ Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-claude-sonnet-4",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Sonnet 4",
+ Description: "Claude Sonnet 4 via Kiro (1.3x credit)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-claude-haiku-4-5",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Haiku 4.5",
+ Description: "Claude Haiku 4.5 via Kiro (0.4x credit)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ // --- 第三方模型 (通过 Kiro 接入) ---
+ {
+ ID: "kiro-deepseek-3-2",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro DeepSeek 3.2",
+ Description: "DeepSeek 3.2 via Kiro",
+ ContextLength: 128000,
+ MaxCompletionTokens: 32768,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-minimax-m2-1",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro MiniMax M2.1",
+ Description: "MiniMax M2.1 via Kiro",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-qwen3-coder-next",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Qwen3 Coder Next",
+ Description: "Qwen3 Coder Next via Kiro",
+ ContextLength: 128000,
+ MaxCompletionTokens: 32768,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-gpt-4o",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro GPT-4o",
+ Description: "OpenAI GPT-4o via Kiro",
+ ContextLength: 128000,
+ MaxCompletionTokens: 16384,
+ },
+ {
+ ID: "kiro-gpt-4",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro GPT-4",
+ Description: "OpenAI GPT-4 via Kiro",
+ ContextLength: 128000,
+ MaxCompletionTokens: 8192,
+ },
+ {
+ ID: "kiro-gpt-4-turbo",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro GPT-4 Turbo",
+ Description: "OpenAI GPT-4 Turbo via Kiro",
+ ContextLength: 128000,
+ MaxCompletionTokens: 16384,
+ },
+ {
+ ID: "kiro-gpt-3-5-turbo",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro GPT-3.5 Turbo",
+ Description: "OpenAI GPT-3.5 Turbo via Kiro",
+ ContextLength: 16384,
+ MaxCompletionTokens: 4096,
+ },
+ // --- Agentic Variants (Optimized for coding agents with chunked writes) ---
+ {
+ ID: "kiro-claude-opus-4-6-agentic",
+ Object: "model",
+ Created: 1736899200, // 2025-01-15
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Opus 4.6 (Agentic)",
+ Description: "Claude Opus 4.6 optimized for coding agents (chunked writes)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-claude-sonnet-4-6-agentic",
+ Object: "model",
+ Created: 1739836800, // 2025-02-18
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Sonnet 4.6 (Agentic)",
+ Description: "Claude Sonnet 4.6 optimized for coding agents (chunked writes)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-claude-opus-4-5-agentic",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Opus 4.5 (Agentic)",
+ Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-claude-sonnet-4-5-agentic",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Sonnet 4.5 (Agentic)",
+ Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-claude-sonnet-4-agentic",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Sonnet 4 (Agentic)",
+ Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ {
+ ID: "kiro-claude-haiku-4-5-agentic",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Kiro Claude Haiku 4.5 (Agentic)",
+ Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ },
+ }
+}
+
+// GetAmazonQModels returns the Amazon Q (AWS CodeWhisperer) model definitions.
+// These models use the same API as Kiro and share the same executor.
+func GetAmazonQModels() []*ModelInfo {
+ return []*ModelInfo{
+ {
+ ID: "amazonq-auto",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro", // Uses Kiro executor - same API
+ DisplayName: "Amazon Q Auto",
+ Description: "Automatic model selection by Amazon Q",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ },
+ {
+ ID: "amazonq-claude-opus-4.5",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Amazon Q Claude Opus 4.5",
+ Description: "Claude Opus 4.5 via Amazon Q (2.2x credit)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ },
+ {
+ ID: "amazonq-claude-sonnet-4.5",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Amazon Q Claude Sonnet 4.5",
+ Description: "Claude Sonnet 4.5 via Amazon Q (1.3x credit)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ },
+ {
+ ID: "amazonq-claude-sonnet-4",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Amazon Q Claude Sonnet 4",
+ Description: "Claude Sonnet 4 via Amazon Q (1.3x credit)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ },
+ {
+ ID: "amazonq-claude-haiku-4.5",
+ Object: "model",
+ Created: 1732752000,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: "Amazon Q Claude Haiku 4.5",
+ Description: "Claude Haiku 4.5 via Amazon Q (0.4x credit)",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ },
+ }
+}
diff --git a/internal/registry/model_definitions_static_data.go b/internal/registry/model_definitions_static_data.go
index 144c4bce..52a559f9 100644
--- a/internal/registry/model_definitions_static_data.go
+++ b/internal/registry/model_definitions_static_data.go
@@ -51,6 +51,18 @@ func GetClaudeModels() []*ModelInfo {
MaxCompletionTokens: 128000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
+ {
+ ID: "claude-sonnet-4-6",
+ Object: "model",
+ Created: 1771286400, // 2026-02-17
+ OwnedBy: "anthropic",
+ Type: "claude",
+ DisplayName: "Claude 4.6 Sonnet",
+ Description: "Best combination of speed and intelligence",
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
+ },
{
ID: "claude-opus-4-5-20251101",
Object: "model",
@@ -904,10 +916,10 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
- "claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
+ "claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gpt-oss-120b-medium": {},
diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go
index 7b8b262e..3fa2a3b5 100644
--- a/internal/registry/model_registry.go
+++ b/internal/registry/model_registry.go
@@ -47,6 +47,8 @@ type ModelInfo struct {
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
// SupportedParameters lists supported parameters
SupportedParameters []string `json:"supported_parameters,omitempty"`
+ // SupportedEndpoints lists supported API endpoints (e.g., "/chat/completions", "/responses").
+ SupportedEndpoints []string `json:"supported_endpoints,omitempty"`
// Thinking holds provider-specific reasoning/thinking budget capabilities.
// This is optional and currently used for Gemini thinking budget normalization.
@@ -499,6 +501,9 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
if len(model.SupportedParameters) > 0 {
copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
+ if len(model.SupportedEndpoints) > 0 {
+ copyModel.SupportedEndpoints = append([]string(nil), model.SupportedEndpoints...)
+ }
return ©Model
}
@@ -1023,9 +1028,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
if len(model.SupportedParameters) > 0 {
result["supported_parameters"] = model.SupportedParameters
}
+ if len(model.SupportedEndpoints) > 0 {
+ result["supported_endpoints"] = model.SupportedEndpoints
+ }
return result
- case "claude":
+ case "claude", "kiro", "antigravity":
+ // Claude, Kiro, and Antigravity all use Claude-compatible format for Claude Code client
result := map[string]any{
"id": model.ID,
"object": "model",
@@ -1040,6 +1049,19 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
if model.DisplayName != "" {
result["display_name"] = model.DisplayName
}
+ // Add thinking support for Claude Code client
+ // Claude Code checks for "thinking" field (simple boolean) to enable tab toggle
+ // Also add "extended_thinking" for detailed budget info
+ if model.Thinking != nil {
+ result["thinking"] = true
+ result["extended_thinking"] = map[string]any{
+ "supported": true,
+ "min": model.Thinking.Min,
+ "max": model.Thinking.Max,
+ "zero_allowed": model.Thinking.ZeroAllowed,
+ "dynamic_allowed": model.Thinking.DynamicAllowed,
+ }
+ }
return result
case "gemini":
diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go
index 24765740..da82b8d0 100644
--- a/internal/runtime/executor/antigravity_executor.go
+++ b/internal/runtime/executor/antigravity_executor.go
@@ -1007,7 +1007,12 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
exec := &AntigravityExecutor{cfg: cfg}
token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth)
- if errToken != nil || token == "" {
+ if errToken != nil {
+ log.Warnf("antigravity executor: fetch models failed for %s: token error: %v", auth.ID, errToken)
+ return nil
+ }
+ if token == "" {
+ log.Warnf("antigravity executor: fetch models failed for %s: got empty token", auth.ID)
return nil
}
if updatedAuth != nil {
@@ -1021,6 +1026,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
modelsURL := baseURL + antigravityModelsPath
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
if errReq != nil {
+ log.Warnf("antigravity executor: fetch models failed for %s: create request error: %v", auth.ID, errReq)
return nil
}
httpReq.Header.Set("Content-Type", "application/json")
@@ -1033,12 +1039,14 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
+ log.Warnf("antigravity executor: fetch models failed for %s: context canceled: %v", auth.ID, errDo)
return nil
}
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
+ log.Warnf("antigravity executor: fetch models failed for %s: request error: %v", auth.ID, errDo)
return nil
}
@@ -1051,6 +1059,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
+ log.Warnf("antigravity executor: fetch models failed for %s: read body error: %v", auth.ID, errRead)
return nil
}
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
@@ -1058,11 +1067,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
+ log.Warnf("antigravity executor: fetch models failed for %s: unexpected status %d, body: %s", auth.ID, httpResp.StatusCode, string(bodyBytes))
return nil
}
result := gjson.GetBytes(bodyBytes, "models")
if !result.Exists() {
+ log.Warnf("antigravity executor: fetch models failed for %s: no models field in response, body: %s", auth.ID, string(bodyBytes))
return nil
}
diff --git a/internal/runtime/executor/cache_helpers.go b/internal/runtime/executor/cache_helpers.go
index b6de886d..1e32f43a 100644
--- a/internal/runtime/executor/cache_helpers.go
+++ b/internal/runtime/executor/cache_helpers.go
@@ -29,6 +29,7 @@ func startCodexCacheCleanup() {
go func() {
ticker := time.NewTicker(codexCacheCleanupInterval)
defer ticker.Stop()
+
for range ticker.C {
purgeExpiredCodexCache()
}
@@ -38,8 +39,10 @@ func startCodexCacheCleanup() {
// purgeExpiredCodexCache removes entries that have expired.
func purgeExpiredCodexCache() {
now := time.Now()
+
codexCacheMu.Lock()
defer codexCacheMu.Unlock()
+
for key, cache := range codexCacheMap {
if cache.Expire.Before(now) {
delete(codexCacheMap, key)
@@ -66,3 +69,10 @@ func setCodexCache(key string, cache codexCache) {
codexCacheMap[key] = cache
codexCacheMu.Unlock()
}
+
+// deleteCodexCache deletes a cache entry.
+func deleteCodexCache(key string) {
+ codexCacheMu.Lock()
+ delete(codexCacheMap, key)
+ codexCacheMu.Unlock()
+}
diff --git a/internal/runtime/executor/github_copilot_executor.go b/internal/runtime/executor/github_copilot_executor.go
new file mode 100644
index 00000000..0189ffc8
--- /dev/null
+++ b/internal/runtime/executor/github_copilot_executor.go
@@ -0,0 +1,1236 @@
+package executor
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/google/uuid"
+ copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
+ cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+ cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
+ sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
+ log "github.com/sirupsen/logrus"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+const (
+ githubCopilotBaseURL = "https://api.githubcopilot.com"
+ githubCopilotChatPath = "/chat/completions"
+ githubCopilotResponsesPath = "/responses"
+ githubCopilotAuthType = "github-copilot"
+ githubCopilotTokenCacheTTL = 25 * time.Minute
+ // tokenExpiryBuffer is the time before expiry when we should refresh the token.
+ tokenExpiryBuffer = 5 * time.Minute
+ // maxScannerBufferSize is the maximum buffer size for SSE scanning (20MB).
+ maxScannerBufferSize = 20_971_520
+
+ // Copilot API header values.
+ copilotUserAgent = "GitHubCopilotChat/0.35.0"
+ copilotEditorVersion = "vscode/1.107.0"
+ copilotPluginVersion = "copilot-chat/0.35.0"
+ copilotIntegrationID = "vscode-chat"
+ copilotOpenAIIntent = "conversation-panel"
+ copilotGitHubAPIVer = "2025-04-01"
+)
+
+// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
+type GitHubCopilotExecutor struct {
+ cfg *config.Config
+ mu sync.RWMutex
+ cache map[string]*cachedAPIToken
+}
+
+// cachedAPIToken stores a cached Copilot API token with its expiry.
+type cachedAPIToken struct {
+ token string
+ apiEndpoint string
+ expiresAt time.Time
+}
+
+// NewGitHubCopilotExecutor constructs a new executor instance.
+func NewGitHubCopilotExecutor(cfg *config.Config) *GitHubCopilotExecutor {
+ return &GitHubCopilotExecutor{
+ cfg: cfg,
+ cache: make(map[string]*cachedAPIToken),
+ }
+}
+
+// Identifier implements ProviderExecutor.
+func (e *GitHubCopilotExecutor) Identifier() string { return githubCopilotAuthType }
+
+// PrepareRequest implements ProviderExecutor.
+func (e *GitHubCopilotExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
+ if req == nil {
+ return nil
+ }
+ ctx := req.Context()
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ apiToken, _, errToken := e.ensureAPIToken(ctx, auth)
+ if errToken != nil {
+ return errToken
+ }
+ e.applyHeaders(req, apiToken, nil)
+ return nil
+}
+
+// HttpRequest injects GitHub Copilot credentials into the request and executes it.
+func (e *GitHubCopilotExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
+ if req == nil {
+ return nil, fmt.Errorf("github-copilot executor: request is nil")
+ }
+ if ctx == nil {
+ ctx = req.Context()
+ }
+ httpReq := req.WithContext(ctx)
+ if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil {
+ return nil, errPrepare
+ }
+ httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
+ return httpClient.Do(httpReq)
+}
+
+// Execute handles non-streaming requests to GitHub Copilot.
+func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
+ apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
+ if errToken != nil {
+ return resp, errToken
+ }
+
+ reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
+ defer reporter.trackFailure(ctx, &err)
+
+ from := opts.SourceFormat
+ useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
+ to := sdktranslator.FromString("openai")
+ if useResponses {
+ to = sdktranslator.FromString("openai-response")
+ }
+ originalPayload := bytes.Clone(req.Payload)
+ if len(opts.OriginalRequest) > 0 {
+ originalPayload = bytes.Clone(opts.OriginalRequest)
+ }
+ originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
+ body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
+ body = e.normalizeModel(req.Model, body)
+ body = flattenAssistantContent(body)
+
+ // Detect vision content before input normalization removes messages
+ hasVision := detectVisionContent(body)
+
+ thinkingProvider := "openai"
+ if useResponses {
+ thinkingProvider = "codex"
+ }
+ body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
+ if err != nil {
+ return resp, err
+ }
+
+ if useResponses {
+ body = normalizeGitHubCopilotResponsesInput(body)
+ body = normalizeGitHubCopilotResponsesTools(body)
+ } else {
+ body = normalizeGitHubCopilotChatTools(body)
+ }
+ requestedModel := payloadRequestedModel(opts, req.Model)
+ body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
+ body, _ = sjson.SetBytes(body, "stream", false)
+
+ path := githubCopilotChatPath
+ if useResponses {
+ path = githubCopilotResponsesPath
+ }
+ url := baseURL + path
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return resp, err
+ }
+ e.applyHeaders(httpReq, apiToken, body)
+
+ // Add Copilot-Vision-Request header if the request contains vision content
+ if hasVision {
+ httpReq.Header.Set("Copilot-Vision-Request", "true")
+ }
+
+ var authID, authLabel, authType, authValue string
+ if auth != nil {
+ authID = auth.ID
+ authLabel = auth.Label
+ authType, authValue = auth.AccountInfo()
+ }
+ recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
+ URL: url,
+ Method: http.MethodPost,
+ Headers: httpReq.Header.Clone(),
+ Body: body,
+ Provider: e.Identifier(),
+ AuthID: authID,
+ AuthLabel: authLabel,
+ AuthType: authType,
+ AuthValue: authValue,
+ })
+
+ httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
+ httpResp, err := httpClient.Do(httpReq)
+ if err != nil {
+ recordAPIResponseError(ctx, e.cfg, err)
+ return resp, err
+ }
+ defer func() {
+ if errClose := httpResp.Body.Close(); errClose != nil {
+ log.Errorf("github-copilot executor: close response body error: %v", errClose)
+ }
+ }()
+
+ recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
+
+ if !isHTTPSuccess(httpResp.StatusCode) {
+ data, _ := io.ReadAll(httpResp.Body)
+ appendAPIResponseChunk(ctx, e.cfg, data)
+ log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
+ err = statusErr{code: httpResp.StatusCode, msg: string(data)}
+ return resp, err
+ }
+
+ data, err := io.ReadAll(httpResp.Body)
+ if err != nil {
+ recordAPIResponseError(ctx, e.cfg, err)
+ return resp, err
+ }
+ appendAPIResponseChunk(ctx, e.cfg, data)
+
+ detail := parseOpenAIUsage(data)
+ if useResponses && detail.TotalTokens == 0 {
+ detail = parseOpenAIResponsesUsage(data)
+ }
+ if detail.TotalTokens > 0 {
+ reporter.publish(ctx, detail)
+ }
+
+ var param any
+ converted := ""
+ if useResponses && from.String() == "claude" {
+ converted = translateGitHubCopilotResponsesNonStreamToClaude(data)
+ } else {
+ converted = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
+ }
+ resp = cliproxyexecutor.Response{Payload: []byte(converted)}
+ reporter.ensurePublished(ctx)
+ return resp, nil
+}
+
+// ExecuteStream handles streaming requests to GitHub Copilot.
+func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
+ apiToken, baseURL, errToken := e.ensureAPIToken(ctx, auth)
+ if errToken != nil {
+ return nil, errToken
+ }
+
+ reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
+ defer reporter.trackFailure(ctx, &err)
+
+ from := opts.SourceFormat
+ useResponses := useGitHubCopilotResponsesEndpoint(from, req.Model)
+ to := sdktranslator.FromString("openai")
+ if useResponses {
+ to = sdktranslator.FromString("openai-response")
+ }
+ originalPayload := bytes.Clone(req.Payload)
+ if len(opts.OriginalRequest) > 0 {
+ originalPayload = bytes.Clone(opts.OriginalRequest)
+ }
+ originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, false)
+ body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
+ body = e.normalizeModel(req.Model, body)
+ body = flattenAssistantContent(body)
+
+ // Detect vision content before input normalization removes messages
+ hasVision := detectVisionContent(body)
+
+ thinkingProvider := "openai"
+ if useResponses {
+ thinkingProvider = "codex"
+ }
+ body, err = thinking.ApplyThinking(body, req.Model, from.String(), thinkingProvider, e.Identifier())
+ if err != nil {
+ return nil, err
+ }
+
+ if useResponses {
+ body = normalizeGitHubCopilotResponsesInput(body)
+ body = normalizeGitHubCopilotResponsesTools(body)
+ } else {
+ body = normalizeGitHubCopilotChatTools(body)
+ }
+ requestedModel := payloadRequestedModel(opts, req.Model)
+ body = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", body, originalTranslated, requestedModel)
+ body, _ = sjson.SetBytes(body, "stream", true)
+ // Enable stream options for usage stats in stream
+ if !useResponses {
+ body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
+ }
+
+ path := githubCopilotChatPath
+ if useResponses {
+ path = githubCopilotResponsesPath
+ }
+ url := baseURL + path
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ e.applyHeaders(httpReq, apiToken, body)
+
+ // Add Copilot-Vision-Request header if the request contains vision content
+ if hasVision {
+ httpReq.Header.Set("Copilot-Vision-Request", "true")
+ }
+
+ var authID, authLabel, authType, authValue string
+ if auth != nil {
+ authID = auth.ID
+ authLabel = auth.Label
+ authType, authValue = auth.AccountInfo()
+ }
+ recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
+ URL: url,
+ Method: http.MethodPost,
+ Headers: httpReq.Header.Clone(),
+ Body: body,
+ Provider: e.Identifier(),
+ AuthID: authID,
+ AuthLabel: authLabel,
+ AuthType: authType,
+ AuthValue: authValue,
+ })
+
+ httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
+ httpResp, err := httpClient.Do(httpReq)
+ if err != nil {
+ recordAPIResponseError(ctx, e.cfg, err)
+ return nil, err
+ }
+
+ recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
+
+ if !isHTTPSuccess(httpResp.StatusCode) {
+ data, readErr := io.ReadAll(httpResp.Body)
+ if errClose := httpResp.Body.Close(); errClose != nil {
+ log.Errorf("github-copilot executor: close response body error: %v", errClose)
+ }
+ if readErr != nil {
+ recordAPIResponseError(ctx, e.cfg, readErr)
+ return nil, readErr
+ }
+ appendAPIResponseChunk(ctx, e.cfg, data)
+ log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
+ err = statusErr{code: httpResp.StatusCode, msg: string(data)}
+ return nil, err
+ }
+
+ out := make(chan cliproxyexecutor.StreamChunk)
+ stream = out
+
+ go func() {
+ defer close(out)
+ defer func() {
+ if errClose := httpResp.Body.Close(); errClose != nil {
+ log.Errorf("github-copilot executor: close response body error: %v", errClose)
+ }
+ }()
+
+ scanner := bufio.NewScanner(httpResp.Body)
+ scanner.Buffer(nil, maxScannerBufferSize)
+ var param any
+
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ appendAPIResponseChunk(ctx, e.cfg, line)
+
+ // Parse SSE data
+ if bytes.HasPrefix(line, dataTag) {
+ data := bytes.TrimSpace(line[5:])
+ if bytes.Equal(data, []byte("[DONE]")) {
+ continue
+ }
+ if detail, ok := parseOpenAIStreamUsage(line); ok {
+ reporter.publish(ctx, detail)
+ } else if useResponses {
+ if detail, ok := parseOpenAIResponsesStreamUsage(line); ok {
+ reporter.publish(ctx, detail)
+ }
+ }
+ }
+
+ var chunks []string
+ if useResponses && from.String() == "claude" {
+ chunks = translateGitHubCopilotResponsesStreamToClaude(bytes.Clone(line), ¶m)
+ } else {
+ chunks = sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
+ }
+ for i := range chunks {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
+ }
+ }
+
+ if errScan := scanner.Err(); errScan != nil {
+ recordAPIResponseError(ctx, e.cfg, errScan)
+ reporter.publishFailure(ctx)
+ out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ } else {
+ reporter.ensurePublished(ctx)
+ }
+ }()
+
+ return stream, nil
+}
+
+// CountTokens is not supported for GitHub Copilot.
+func (e *GitHubCopilotExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
+ return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for github-copilot"}
+}
+
+// Refresh validates the GitHub token is still working.
+// GitHub OAuth tokens don't expire traditionally, so we just validate.
+func (e *GitHubCopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
+ if auth == nil {
+ return nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
+ }
+
+ // Get the GitHub access token
+ accessToken := metaStringValue(auth.Metadata, "access_token")
+ if accessToken == "" {
+ return auth, nil
+ }
+
+ // Validate the token can still get a Copilot API token
+ copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
+ _, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
+ if err != nil {
+ return nil, statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("github-copilot token validation failed: %v", err)}
+ }
+
+ return auth, nil
+}
+
+// ensureAPIToken gets or refreshes the Copilot API token.
+func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, string, error) {
+ if auth == nil {
+ return "", "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
+ }
+
+ // Get the GitHub access token
+ accessToken := metaStringValue(auth.Metadata, "access_token")
+ if accessToken == "" {
+ return "", "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
+ }
+
+ // Check for cached API token using thread-safe access
+ e.mu.RLock()
+ if cached, ok := e.cache[accessToken]; ok && cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) {
+ e.mu.RUnlock()
+ return cached.token, cached.apiEndpoint, nil
+ }
+ e.mu.RUnlock()
+
+ // Get a new Copilot API token
+ copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
+ apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
+ if err != nil {
+ return "", "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
+ }
+
+ // Use endpoint from token response, fall back to default
+ apiEndpoint := githubCopilotBaseURL
+ if apiToken.Endpoints.API != "" {
+ apiEndpoint = strings.TrimRight(apiToken.Endpoints.API, "/")
+ }
+
+ // Cache the token with thread-safe access
+ expiresAt := time.Now().Add(githubCopilotTokenCacheTTL)
+ if apiToken.ExpiresAt > 0 {
+ expiresAt = time.Unix(apiToken.ExpiresAt, 0)
+ }
+ e.mu.Lock()
+ e.cache[accessToken] = &cachedAPIToken{
+ token: apiToken.Token,
+ apiEndpoint: apiEndpoint,
+ expiresAt: expiresAt,
+ }
+ e.mu.Unlock()
+
+ return apiToken.Token, apiEndpoint, nil
+}
+
+// applyHeaders sets the required headers for GitHub Copilot API requests.
+func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string, body []byte) {
+ r.Header.Set("Content-Type", "application/json")
+ r.Header.Set("Authorization", "Bearer "+apiToken)
+ r.Header.Set("Accept", "application/json")
+ r.Header.Set("User-Agent", copilotUserAgent)
+ r.Header.Set("Editor-Version", copilotEditorVersion)
+ r.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
+ r.Header.Set("Openai-Intent", copilotOpenAIIntent)
+ r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
+ r.Header.Set("X-Github-Api-Version", copilotGitHubAPIVer)
+ r.Header.Set("X-Request-Id", uuid.NewString())
+
+ initiator := "user"
+ if len(body) > 0 {
+ if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
+ for _, msg := range messages.Array() {
+ role := msg.Get("role").String()
+ if role == "assistant" || role == "tool" {
+ initiator = "agent"
+ break
+ }
+ }
+ }
+ }
+ r.Header.Set("X-Initiator", initiator)
+}
+
+// detectVisionContent checks if the request body contains vision/image content.
+// Returns true if the request includes image_url or image type content blocks.
+func detectVisionContent(body []byte) bool {
+ // Parse messages array
+ messagesResult := gjson.GetBytes(body, "messages")
+ if !messagesResult.Exists() || !messagesResult.IsArray() {
+ return false
+ }
+
+ // Check each message for vision content
+ for _, message := range messagesResult.Array() {
+ content := message.Get("content")
+
+ // If content is an array, check each content block
+ if content.IsArray() {
+ for _, block := range content.Array() {
+ blockType := block.Get("type").String()
+ // Check for image_url or image type
+ if blockType == "image_url" || blockType == "image" {
+ return true
+ }
+ }
+ }
+ }
+
+ return false
+}
+
+// normalizeModel strips the suffix (e.g. "(medium)") from the model name
+// before sending to GitHub Copilot, as the upstream API does not accept
+// suffixed model identifiers.
+func (e *GitHubCopilotExecutor) normalizeModel(model string, body []byte) []byte {
+ baseModel := thinking.ParseSuffix(model).ModelName
+ if baseModel != model {
+ body, _ = sjson.SetBytes(body, "model", baseModel)
+ }
+ return body
+}
+
+func useGitHubCopilotResponsesEndpoint(sourceFormat sdktranslator.Format, model string) bool {
+ if sourceFormat.String() == "openai-response" {
+ return true
+ }
+ baseModel := strings.ToLower(thinking.ParseSuffix(model).ModelName)
+ return strings.Contains(baseModel, "codex")
+}
+
+// flattenAssistantContent converts assistant message content from array format
+// to a joined string. GitHub Copilot requires assistant content as a string;
+// sending it as an array causes Claude models to re-answer all previous prompts.
+func flattenAssistantContent(body []byte) []byte {
+ messages := gjson.GetBytes(body, "messages")
+ if !messages.Exists() || !messages.IsArray() {
+ return body
+ }
+ result := body
+ for i, msg := range messages.Array() {
+ if msg.Get("role").String() != "assistant" {
+ continue
+ }
+ content := msg.Get("content")
+ if !content.Exists() || !content.IsArray() {
+ continue
+ }
+ // Skip flattening if the content contains non-text blocks (tool_use, thinking, etc.)
+ hasNonText := false
+ for _, part := range content.Array() {
+ if t := part.Get("type").String(); t != "" && t != "text" {
+ hasNonText = true
+ break
+ }
+ }
+ if hasNonText {
+ continue
+ }
+ var textParts []string
+ for _, part := range content.Array() {
+ if part.Get("type").String() == "text" {
+ if t := part.Get("text").String(); t != "" {
+ textParts = append(textParts, t)
+ }
+ }
+ }
+ joined := strings.Join(textParts, "")
+ path := fmt.Sprintf("messages.%d.content", i)
+ result, _ = sjson.SetBytes(result, path, joined)
+ }
+ return result
+}
+
+func normalizeGitHubCopilotChatTools(body []byte) []byte {
+ tools := gjson.GetBytes(body, "tools")
+ if tools.Exists() {
+ filtered := "[]"
+ if tools.IsArray() {
+ for _, tool := range tools.Array() {
+ if tool.Get("type").String() != "function" {
+ continue
+ }
+ filtered, _ = sjson.SetRaw(filtered, "-1", tool.Raw)
+ }
+ }
+ body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
+ }
+
+ toolChoice := gjson.GetBytes(body, "tool_choice")
+ if !toolChoice.Exists() {
+ return body
+ }
+ if toolChoice.Type == gjson.String {
+ switch toolChoice.String() {
+ case "auto", "none", "required":
+ return body
+ }
+ }
+ body, _ = sjson.SetBytes(body, "tool_choice", "auto")
+ return body
+}
+
+func normalizeGitHubCopilotResponsesInput(body []byte) []byte {
+ input := gjson.GetBytes(body, "input")
+ if input.Exists() {
+ // If input is already a string or array, keep it as-is.
+ if input.Type == gjson.String || input.IsArray() {
+ return body
+ }
+ // Non-string/non-array input: stringify as fallback.
+ body, _ = sjson.SetBytes(body, "input", input.Raw)
+ return body
+ }
+
+ // Convert Claude messages format to OpenAI Responses API input array.
+ // This preserves the conversation structure (roles, tool calls, tool results)
+ // which is critical for multi-turn tool-use conversations.
+ inputArr := "[]"
+
+ // System messages → developer role
+ if system := gjson.GetBytes(body, "system"); system.Exists() {
+ var systemParts []string
+ if system.IsArray() {
+ for _, part := range system.Array() {
+ if txt := part.Get("text").String(); txt != "" {
+ systemParts = append(systemParts, txt)
+ }
+ }
+ } else if system.Type == gjson.String {
+ systemParts = append(systemParts, system.String())
+ }
+ if len(systemParts) > 0 {
+ msg := `{"type":"message","role":"developer","content":[]}`
+ for _, txt := range systemParts {
+ part := `{"type":"input_text","text":""}`
+ part, _ = sjson.Set(part, "text", txt)
+ msg, _ = sjson.SetRaw(msg, "content.-1", part)
+ }
+ inputArr, _ = sjson.SetRaw(inputArr, "-1", msg)
+ }
+ }
+
+ // Messages → structured input items
+ if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
+ for _, msg := range messages.Array() {
+ role := msg.Get("role").String()
+ content := msg.Get("content")
+
+ if !content.Exists() {
+ continue
+ }
+
+ // Simple string content
+ if content.Type == gjson.String {
+ textType := "input_text"
+ if role == "assistant" {
+ textType = "output_text"
+ }
+ item := `{"type":"message","role":"","content":[]}`
+ item, _ = sjson.Set(item, "role", role)
+ part := fmt.Sprintf(`{"type":"%s","text":""}`, textType)
+ part, _ = sjson.Set(part, "text", content.String())
+ item, _ = sjson.SetRaw(item, "content.-1", part)
+ inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
+ continue
+ }
+
+ if !content.IsArray() {
+ continue
+ }
+
+ // Array content: split into message parts vs tool items
+ var msgParts []string
+ for _, c := range content.Array() {
+ cType := c.Get("type").String()
+ switch cType {
+ case "text":
+ textType := "input_text"
+ if role == "assistant" {
+ textType = "output_text"
+ }
+ part := fmt.Sprintf(`{"type":"%s","text":""}`, textType)
+ part, _ = sjson.Set(part, "text", c.Get("text").String())
+ msgParts = append(msgParts, part)
+ case "image":
+ source := c.Get("source")
+ if source.Exists() {
+ data := source.Get("data").String()
+ if data == "" {
+ data = source.Get("base64").String()
+ }
+ mediaType := source.Get("media_type").String()
+ if mediaType == "" {
+ mediaType = source.Get("mime_type").String()
+ }
+ if mediaType == "" {
+ mediaType = "application/octet-stream"
+ }
+ if data != "" {
+ part := `{"type":"input_image","image_url":""}`
+ part, _ = sjson.Set(part, "image_url", fmt.Sprintf("data:%s;base64,%s", mediaType, data))
+ msgParts = append(msgParts, part)
+ }
+ }
+ case "tool_use":
+ // Flush any accumulated message parts first
+ if len(msgParts) > 0 {
+ item := `{"type":"message","role":"","content":[]}`
+ item, _ = sjson.Set(item, "role", role)
+ for _, p := range msgParts {
+ item, _ = sjson.SetRaw(item, "content.-1", p)
+ }
+ inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
+ msgParts = nil
+ }
+ fc := `{"type":"function_call","call_id":"","name":"","arguments":""}`
+ fc, _ = sjson.Set(fc, "call_id", c.Get("id").String())
+ fc, _ = sjson.Set(fc, "name", c.Get("name").String())
+ if inputRaw := c.Get("input"); inputRaw.Exists() {
+ fc, _ = sjson.Set(fc, "arguments", inputRaw.Raw)
+ }
+ inputArr, _ = sjson.SetRaw(inputArr, "-1", fc)
+ case "tool_result":
+ // Flush any accumulated message parts first
+ if len(msgParts) > 0 {
+ item := `{"type":"message","role":"","content":[]}`
+ item, _ = sjson.Set(item, "role", role)
+ for _, p := range msgParts {
+ item, _ = sjson.SetRaw(item, "content.-1", p)
+ }
+ inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
+ msgParts = nil
+ }
+ fco := `{"type":"function_call_output","call_id":"","output":""}`
+ fco, _ = sjson.Set(fco, "call_id", c.Get("tool_use_id").String())
+ // Extract output text
+ resultContent := c.Get("content")
+ if resultContent.Type == gjson.String {
+ fco, _ = sjson.Set(fco, "output", resultContent.String())
+ } else if resultContent.IsArray() {
+ var resultParts []string
+ for _, rc := range resultContent.Array() {
+ if txt := rc.Get("text").String(); txt != "" {
+ resultParts = append(resultParts, txt)
+ }
+ }
+ fco, _ = sjson.Set(fco, "output", strings.Join(resultParts, "\n"))
+ } else if resultContent.Exists() {
+ fco, _ = sjson.Set(fco, "output", resultContent.String())
+ }
+ inputArr, _ = sjson.SetRaw(inputArr, "-1", fco)
+ case "thinking":
+ // Skip thinking blocks - not part of the API input
+ }
+ }
+
+ // Flush remaining message parts
+ if len(msgParts) > 0 {
+ item := `{"type":"message","role":"","content":[]}`
+ item, _ = sjson.Set(item, "role", role)
+ for _, p := range msgParts {
+ item, _ = sjson.SetRaw(item, "content.-1", p)
+ }
+ inputArr, _ = sjson.SetRaw(inputArr, "-1", item)
+ }
+ }
+ }
+
+ body, _ = sjson.SetRawBytes(body, "input", []byte(inputArr))
+ // Remove messages/system since we've converted them to input
+ body, _ = sjson.DeleteBytes(body, "messages")
+ body, _ = sjson.DeleteBytes(body, "system")
+ return body
+}
+
+func normalizeGitHubCopilotResponsesTools(body []byte) []byte {
+ tools := gjson.GetBytes(body, "tools")
+ if tools.Exists() {
+ filtered := "[]"
+ if tools.IsArray() {
+ for _, tool := range tools.Array() {
+ toolType := tool.Get("type").String()
+ // Accept OpenAI format (type="function") and Claude format
+ // (no type field, but has top-level name + input_schema).
+ if toolType != "" && toolType != "function" {
+ continue
+ }
+ name := tool.Get("name").String()
+ if name == "" {
+ name = tool.Get("function.name").String()
+ }
+ if name == "" {
+ continue
+ }
+ normalized := `{"type":"function","name":""}`
+ normalized, _ = sjson.Set(normalized, "name", name)
+ if desc := tool.Get("description").String(); desc != "" {
+ normalized, _ = sjson.Set(normalized, "description", desc)
+ } else if desc = tool.Get("function.description").String(); desc != "" {
+ normalized, _ = sjson.Set(normalized, "description", desc)
+ }
+ if params := tool.Get("parameters"); params.Exists() {
+ normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
+ } else if params = tool.Get("function.parameters"); params.Exists() {
+ normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
+ } else if params = tool.Get("input_schema"); params.Exists() {
+ normalized, _ = sjson.SetRaw(normalized, "parameters", params.Raw)
+ }
+ filtered, _ = sjson.SetRaw(filtered, "-1", normalized)
+ }
+ }
+ body, _ = sjson.SetRawBytes(body, "tools", []byte(filtered))
+ }
+
+ toolChoice := gjson.GetBytes(body, "tool_choice")
+ if !toolChoice.Exists() {
+ return body
+ }
+ if toolChoice.Type == gjson.String {
+ switch toolChoice.String() {
+ case "auto", "none", "required":
+ return body
+ default:
+ body, _ = sjson.SetBytes(body, "tool_choice", "auto")
+ return body
+ }
+ }
+ if toolChoice.Type == gjson.JSON {
+ choiceType := toolChoice.Get("type").String()
+ if choiceType == "function" {
+ name := toolChoice.Get("name").String()
+ if name == "" {
+ name = toolChoice.Get("function.name").String()
+ }
+ if name != "" {
+ normalized := `{"type":"function","name":""}`
+ normalized, _ = sjson.Set(normalized, "name", name)
+ body, _ = sjson.SetRawBytes(body, "tool_choice", []byte(normalized))
+ return body
+ }
+ }
+ }
+ body, _ = sjson.SetBytes(body, "tool_choice", "auto")
+ return body
+}
+
+func collectTextFromNode(node gjson.Result) string {
+ if !node.Exists() {
+ return ""
+ }
+ if node.Type == gjson.String {
+ return node.String()
+ }
+ if node.IsArray() {
+ var parts []string
+ for _, item := range node.Array() {
+ if item.Type == gjson.String {
+ if text := item.String(); text != "" {
+ parts = append(parts, text)
+ }
+ continue
+ }
+ if text := item.Get("text").String(); text != "" {
+ parts = append(parts, text)
+ continue
+ }
+ if nested := collectTextFromNode(item.Get("content")); nested != "" {
+ parts = append(parts, nested)
+ }
+ }
+ return strings.Join(parts, "\n")
+ }
+ if node.Type == gjson.JSON {
+ if text := node.Get("text").String(); text != "" {
+ return text
+ }
+ if nested := collectTextFromNode(node.Get("content")); nested != "" {
+ return nested
+ }
+ return node.Raw
+ }
+ return node.String()
+}
+
+type githubCopilotResponsesStreamToolState struct {
+ Index int
+ ID string
+ Name string
+}
+
+type githubCopilotResponsesStreamState struct {
+ MessageStarted bool
+ MessageStopSent bool
+ TextBlockStarted bool
+ TextBlockIndex int
+ NextContentIndex int
+ HasToolUse bool
+ ReasoningActive bool
+ ReasoningIndex int
+ OutputIndexToTool map[int]*githubCopilotResponsesStreamToolState
+ ItemIDToTool map[string]*githubCopilotResponsesStreamToolState
+}
+
+func translateGitHubCopilotResponsesNonStreamToClaude(data []byte) string {
+ root := gjson.ParseBytes(data)
+ out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
+ out, _ = sjson.Set(out, "id", root.Get("id").String())
+ out, _ = sjson.Set(out, "model", root.Get("model").String())
+
+ hasToolUse := false
+ if output := root.Get("output"); output.Exists() && output.IsArray() {
+ for _, item := range output.Array() {
+ switch item.Get("type").String() {
+ case "reasoning":
+ var thinkingText string
+ if summary := item.Get("summary"); summary.Exists() && summary.IsArray() {
+ var parts []string
+ for _, part := range summary.Array() {
+ if txt := part.Get("text").String(); txt != "" {
+ parts = append(parts, txt)
+ }
+ }
+ thinkingText = strings.Join(parts, "")
+ }
+ if thinkingText == "" {
+ if content := item.Get("content"); content.Exists() && content.IsArray() {
+ var parts []string
+ for _, part := range content.Array() {
+ if txt := part.Get("text").String(); txt != "" {
+ parts = append(parts, txt)
+ }
+ }
+ thinkingText = strings.Join(parts, "")
+ }
+ }
+ if thinkingText != "" {
+ block := `{"type":"thinking","thinking":""}`
+ block, _ = sjson.Set(block, "thinking", thinkingText)
+ out, _ = sjson.SetRaw(out, "content.-1", block)
+ }
+ case "message":
+ if content := item.Get("content"); content.Exists() && content.IsArray() {
+ for _, part := range content.Array() {
+ if part.Get("type").String() != "output_text" {
+ continue
+ }
+ text := part.Get("text").String()
+ if text == "" {
+ continue
+ }
+ block := `{"type":"text","text":""}`
+ block, _ = sjson.Set(block, "text", text)
+ out, _ = sjson.SetRaw(out, "content.-1", block)
+ }
+ }
+ case "function_call":
+ hasToolUse = true
+ toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
+ toolID := item.Get("call_id").String()
+ if toolID == "" {
+ toolID = item.Get("id").String()
+ }
+ toolUse, _ = sjson.Set(toolUse, "id", toolID)
+ toolUse, _ = sjson.Set(toolUse, "name", item.Get("name").String())
+ if args := item.Get("arguments").String(); args != "" && gjson.Valid(args) {
+ argObj := gjson.Parse(args)
+ if argObj.IsObject() {
+ toolUse, _ = sjson.SetRaw(toolUse, "input", argObj.Raw)
+ }
+ }
+ out, _ = sjson.SetRaw(out, "content.-1", toolUse)
+ }
+ }
+ }
+
+ inputTokens := root.Get("usage.input_tokens").Int()
+ outputTokens := root.Get("usage.output_tokens").Int()
+ cachedTokens := root.Get("usage.input_tokens_details.cached_tokens").Int()
+ if cachedTokens > 0 && inputTokens >= cachedTokens {
+ inputTokens -= cachedTokens
+ }
+ out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
+ out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
+ if cachedTokens > 0 {
+ out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
+ }
+ if hasToolUse {
+ out, _ = sjson.Set(out, "stop_reason", "tool_use")
+ } else if sr := root.Get("stop_reason").String(); sr == "max_tokens" || sr == "stop" {
+ out, _ = sjson.Set(out, "stop_reason", sr)
+ } else {
+ out, _ = sjson.Set(out, "stop_reason", "end_turn")
+ }
+ return out
+}
+
+func translateGitHubCopilotResponsesStreamToClaude(line []byte, param *any) []string {
+ if *param == nil {
+ *param = &githubCopilotResponsesStreamState{
+ TextBlockIndex: -1,
+ OutputIndexToTool: make(map[int]*githubCopilotResponsesStreamToolState),
+ ItemIDToTool: make(map[string]*githubCopilotResponsesStreamToolState),
+ }
+ }
+ state := (*param).(*githubCopilotResponsesStreamState)
+
+ if !bytes.HasPrefix(line, dataTag) {
+ return nil
+ }
+ payload := bytes.TrimSpace(line[5:])
+ if bytes.Equal(payload, []byte("[DONE]")) {
+ return nil
+ }
+ if !gjson.ValidBytes(payload) {
+ return nil
+ }
+
+ event := gjson.GetBytes(payload, "type").String()
+ results := make([]string, 0, 4)
+ ensureMessageStart := func() {
+ if state.MessageStarted {
+ return
+ }
+ messageStart := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
+ messageStart, _ = sjson.Set(messageStart, "message.id", gjson.GetBytes(payload, "response.id").String())
+ messageStart, _ = sjson.Set(messageStart, "message.model", gjson.GetBytes(payload, "response.model").String())
+ results = append(results, "event: message_start\ndata: "+messageStart+"\n\n")
+ state.MessageStarted = true
+ }
+ startTextBlockIfNeeded := func() {
+ if state.TextBlockStarted {
+ return
+ }
+ if state.TextBlockIndex < 0 {
+ state.TextBlockIndex = state.NextContentIndex
+ state.NextContentIndex++
+ }
+ contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
+ contentBlockStart, _ = sjson.Set(contentBlockStart, "index", state.TextBlockIndex)
+ results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
+ state.TextBlockStarted = true
+ }
+ stopTextBlockIfNeeded := func() {
+ if !state.TextBlockStarted {
+ return
+ }
+ contentBlockStop := `{"type":"content_block_stop","index":0}`
+ contentBlockStop, _ = sjson.Set(contentBlockStop, "index", state.TextBlockIndex)
+ results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
+ state.TextBlockStarted = false
+ state.TextBlockIndex = -1
+ }
+ resolveTool := func(itemID string, outputIndex int) *githubCopilotResponsesStreamToolState {
+ if itemID != "" {
+ if tool, ok := state.ItemIDToTool[itemID]; ok {
+ return tool
+ }
+ }
+ if tool, ok := state.OutputIndexToTool[outputIndex]; ok {
+ if itemID != "" {
+ state.ItemIDToTool[itemID] = tool
+ }
+ return tool
+ }
+ return nil
+ }
+
+ switch event {
+ case "response.created":
+ ensureMessageStart()
+ case "response.output_text.delta":
+ ensureMessageStart()
+ startTextBlockIfNeeded()
+ delta := gjson.GetBytes(payload, "delta").String()
+ if delta != "" {
+ contentDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
+ contentDelta, _ = sjson.Set(contentDelta, "index", state.TextBlockIndex)
+ contentDelta, _ = sjson.Set(contentDelta, "delta.text", delta)
+ results = append(results, "event: content_block_delta\ndata: "+contentDelta+"\n\n")
+ }
+ case "response.reasoning_summary_part.added":
+ ensureMessageStart()
+ state.ReasoningActive = true
+ state.ReasoningIndex = state.NextContentIndex
+ state.NextContentIndex++
+ thinkingStart := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
+ thinkingStart, _ = sjson.Set(thinkingStart, "index", state.ReasoningIndex)
+ results = append(results, "event: content_block_start\ndata: "+thinkingStart+"\n\n")
+ case "response.reasoning_summary_text.delta":
+ if state.ReasoningActive {
+ delta := gjson.GetBytes(payload, "delta").String()
+ if delta != "" {
+ thinkingDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
+ thinkingDelta, _ = sjson.Set(thinkingDelta, "index", state.ReasoningIndex)
+ thinkingDelta, _ = sjson.Set(thinkingDelta, "delta.thinking", delta)
+ results = append(results, "event: content_block_delta\ndata: "+thinkingDelta+"\n\n")
+ }
+ }
+ case "response.reasoning_summary_part.done":
+ if state.ReasoningActive {
+ thinkingStop := `{"type":"content_block_stop","index":0}`
+ thinkingStop, _ = sjson.Set(thinkingStop, "index", state.ReasoningIndex)
+ results = append(results, "event: content_block_stop\ndata: "+thinkingStop+"\n\n")
+ state.ReasoningActive = false
+ }
+ case "response.output_item.added":
+ if gjson.GetBytes(payload, "item.type").String() != "function_call" {
+ break
+ }
+ ensureMessageStart()
+ stopTextBlockIfNeeded()
+ state.HasToolUse = true
+ tool := &githubCopilotResponsesStreamToolState{
+ Index: state.NextContentIndex,
+ ID: gjson.GetBytes(payload, "item.call_id").String(),
+ Name: gjson.GetBytes(payload, "item.name").String(),
+ }
+ if tool.ID == "" {
+ tool.ID = gjson.GetBytes(payload, "item.id").String()
+ }
+ state.NextContentIndex++
+ outputIndex := int(gjson.GetBytes(payload, "output_index").Int())
+ state.OutputIndexToTool[outputIndex] = tool
+ if itemID := gjson.GetBytes(payload, "item.id").String(); itemID != "" {
+ state.ItemIDToTool[itemID] = tool
+ }
+ contentBlockStart := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
+ contentBlockStart, _ = sjson.Set(contentBlockStart, "index", tool.Index)
+ contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.id", tool.ID)
+ contentBlockStart, _ = sjson.Set(contentBlockStart, "content_block.name", tool.Name)
+ results = append(results, "event: content_block_start\ndata: "+contentBlockStart+"\n\n")
+ case "response.output_item.delta":
+ item := gjson.GetBytes(payload, "item")
+ if item.Get("type").String() != "function_call" {
+ break
+ }
+ tool := resolveTool(item.Get("id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
+ if tool == nil {
+ break
+ }
+ partial := gjson.GetBytes(payload, "delta").String()
+ if partial == "" {
+ partial = item.Get("arguments").String()
+ }
+ if partial == "" {
+ break
+ }
+ inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
+ inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
+ inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
+ results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
+ case "response.function_call_arguments.delta":
+ // Copilot sends tool call arguments via this event type (not response.output_item.delta).
+ // Data format: {"delta":"...", "item_id":"...", "output_index":N, ...}
+ itemID := gjson.GetBytes(payload, "item_id").String()
+ outputIndex := int(gjson.GetBytes(payload, "output_index").Int())
+ tool := resolveTool(itemID, outputIndex)
+ if tool == nil {
+ break
+ }
+ partial := gjson.GetBytes(payload, "delta").String()
+ if partial == "" {
+ break
+ }
+ inputDelta := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
+ inputDelta, _ = sjson.Set(inputDelta, "index", tool.Index)
+ inputDelta, _ = sjson.Set(inputDelta, "delta.partial_json", partial)
+ results = append(results, "event: content_block_delta\ndata: "+inputDelta+"\n\n")
+ case "response.output_item.done":
+ if gjson.GetBytes(payload, "item.type").String() != "function_call" {
+ break
+ }
+ tool := resolveTool(gjson.GetBytes(payload, "item.id").String(), int(gjson.GetBytes(payload, "output_index").Int()))
+ if tool == nil {
+ break
+ }
+ contentBlockStop := `{"type":"content_block_stop","index":0}`
+ contentBlockStop, _ = sjson.Set(contentBlockStop, "index", tool.Index)
+ results = append(results, "event: content_block_stop\ndata: "+contentBlockStop+"\n\n")
+ case "response.completed":
+ ensureMessageStart()
+ stopTextBlockIfNeeded()
+ if !state.MessageStopSent {
+ stopReason := "end_turn"
+ if state.HasToolUse {
+ stopReason = "tool_use"
+ } else if sr := gjson.GetBytes(payload, "response.stop_reason").String(); sr == "max_tokens" || sr == "stop" {
+ stopReason = sr
+ }
+ inputTokens := gjson.GetBytes(payload, "response.usage.input_tokens").Int()
+ outputTokens := gjson.GetBytes(payload, "response.usage.output_tokens").Int()
+ cachedTokens := gjson.GetBytes(payload, "response.usage.input_tokens_details.cached_tokens").Int()
+ if cachedTokens > 0 && inputTokens >= cachedTokens {
+ inputTokens -= cachedTokens
+ }
+ messageDelta := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
+ messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason)
+ messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", inputTokens)
+ messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", outputTokens)
+ if cachedTokens > 0 {
+ messageDelta, _ = sjson.Set(messageDelta, "usage.cache_read_input_tokens", cachedTokens)
+ }
+ results = append(results, "event: message_delta\ndata: "+messageDelta+"\n\n")
+ results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
+ state.MessageStopSent = true
+ }
+ }
+
+ return results
+}
+
+// isHTTPSuccess checks if the status code indicates success (2xx).
+func isHTTPSuccess(statusCode int) bool {
+ return statusCode >= 200 && statusCode < 300
+}
diff --git a/internal/runtime/executor/github_copilot_executor_test.go b/internal/runtime/executor/github_copilot_executor_test.go
new file mode 100644
index 00000000..39868ef7
--- /dev/null
+++ b/internal/runtime/executor/github_copilot_executor_test.go
@@ -0,0 +1,333 @@
+package executor
+
+import (
+ "net/http"
+ "strings"
+ "testing"
+
+ sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
+ "github.com/tidwall/gjson"
+)
+
+func TestGitHubCopilotNormalizeModel_StripsSuffix(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ model string
+ wantModel string
+ }{
+ {
+ name: "suffix stripped",
+ model: "claude-opus-4.6(medium)",
+ wantModel: "claude-opus-4.6",
+ },
+ {
+ name: "no suffix unchanged",
+ model: "claude-opus-4.6",
+ wantModel: "claude-opus-4.6",
+ },
+ {
+ name: "different suffix stripped",
+ model: "gpt-4o(high)",
+ wantModel: "gpt-4o",
+ },
+ {
+ name: "numeric suffix stripped",
+ model: "gemini-2.5-pro(8192)",
+ wantModel: "gemini-2.5-pro",
+ },
+ }
+
+ e := &GitHubCopilotExecutor{}
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ body := []byte(`{"model":"` + tt.model + `","messages":[]}`)
+ got := e.normalizeModel(tt.model, body)
+
+ gotModel := gjson.GetBytes(got, "model").String()
+ if gotModel != tt.wantModel {
+ t.Fatalf("normalizeModel() model = %q, want %q", gotModel, tt.wantModel)
+ }
+ })
+ }
+}
+
+func TestUseGitHubCopilotResponsesEndpoint_OpenAIResponseSource(t *testing.T) {
+ t.Parallel()
+ if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai-response"), "claude-3-5-sonnet") {
+ t.Fatal("expected openai-response source to use /responses")
+ }
+}
+
+func TestUseGitHubCopilotResponsesEndpoint_CodexModel(t *testing.T) {
+ t.Parallel()
+ if !useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "gpt-5-codex") {
+ t.Fatal("expected codex model to use /responses")
+ }
+}
+
+func TestUseGitHubCopilotResponsesEndpoint_DefaultChat(t *testing.T) {
+ t.Parallel()
+ if useGitHubCopilotResponsesEndpoint(sdktranslator.FromString("openai"), "claude-3-5-sonnet") {
+ t.Fatal("expected default openai source with non-codex model to use /chat/completions")
+ }
+}
+
+func TestNormalizeGitHubCopilotChatTools_KeepFunctionOnly(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"tools":[{"type":"function","function":{"name":"ok"}},{"type":"code_interpreter"}],"tool_choice":"auto"}`)
+ got := normalizeGitHubCopilotChatTools(body)
+ tools := gjson.GetBytes(got, "tools").Array()
+ if len(tools) != 1 {
+ t.Fatalf("tools len = %d, want 1", len(tools))
+ }
+ if tools[0].Get("type").String() != "function" {
+ t.Fatalf("tool type = %q, want function", tools[0].Get("type").String())
+ }
+}
+
+func TestNormalizeGitHubCopilotChatTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"tools":[],"tool_choice":{"type":"function","function":{"name":"x"}}}`)
+ got := normalizeGitHubCopilotChatTools(body)
+ if gjson.GetBytes(got, "tool_choice").String() != "auto" {
+ t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
+ }
+}
+
+func TestNormalizeGitHubCopilotResponsesInput_MissingInputExtractedFromSystemAndMessages(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"system":"sys text","messages":[{"role":"user","content":"user text"},{"role":"assistant","content":[{"type":"text","text":"assistant text"}]}]}`)
+ got := normalizeGitHubCopilotResponsesInput(body)
+ in := gjson.GetBytes(got, "input")
+ if !in.IsArray() {
+ t.Fatalf("input type = %v, want array", in.Type)
+ }
+ raw := in.Raw
+ if !strings.Contains(raw, "sys text") || !strings.Contains(raw, "user text") || !strings.Contains(raw, "assistant text") {
+ t.Fatalf("input = %s, want structured array with all texts", raw)
+ }
+ if gjson.GetBytes(got, "messages").Exists() {
+ t.Fatal("messages should be removed after conversion")
+ }
+ if gjson.GetBytes(got, "system").Exists() {
+ t.Fatal("system should be removed after conversion")
+ }
+}
+
+func TestNormalizeGitHubCopilotResponsesInput_NonStringInputStringified(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"input":{"foo":"bar"}}`)
+ got := normalizeGitHubCopilotResponsesInput(body)
+ in := gjson.GetBytes(got, "input")
+ if in.Type != gjson.String {
+ t.Fatalf("input type = %v, want string", in.Type)
+ }
+ if !strings.Contains(in.String(), "foo") {
+ t.Fatalf("input = %q, want stringified object", in.String())
+ }
+}
+
+func TestNormalizeGitHubCopilotResponsesTools_FlattenFunctionTools(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"tools":[{"type":"function","function":{"name":"sum","description":"d","parameters":{"type":"object"}}},{"type":"web_search"}]}`)
+ got := normalizeGitHubCopilotResponsesTools(body)
+ tools := gjson.GetBytes(got, "tools").Array()
+ if len(tools) != 1 {
+ t.Fatalf("tools len = %d, want 1", len(tools))
+ }
+ if tools[0].Get("name").String() != "sum" {
+ t.Fatalf("tools[0].name = %q, want sum", tools[0].Get("name").String())
+ }
+ if !tools[0].Get("parameters").Exists() {
+ t.Fatal("expected parameters to be preserved")
+ }
+}
+
+func TestNormalizeGitHubCopilotResponsesTools_ClaudeFormatTools(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"tools":[{"name":"Bash","description":"Run commands","input_schema":{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}},{"name":"Read","description":"Read files","input_schema":{"type":"object","properties":{"path":{"type":"string"}}}}]}`)
+ got := normalizeGitHubCopilotResponsesTools(body)
+ tools := gjson.GetBytes(got, "tools").Array()
+ if len(tools) != 2 {
+ t.Fatalf("tools len = %d, want 2", len(tools))
+ }
+ if tools[0].Get("type").String() != "function" {
+ t.Fatalf("tools[0].type = %q, want function", tools[0].Get("type").String())
+ }
+ if tools[0].Get("name").String() != "Bash" {
+ t.Fatalf("tools[0].name = %q, want Bash", tools[0].Get("name").String())
+ }
+ if tools[0].Get("description").String() != "Run commands" {
+ t.Fatalf("tools[0].description = %q, want 'Run commands'", tools[0].Get("description").String())
+ }
+ if !tools[0].Get("parameters").Exists() {
+ t.Fatal("expected parameters to be set from input_schema")
+ }
+ if tools[0].Get("parameters.properties.command").Exists() != true {
+ t.Fatal("expected parameters.properties.command to exist")
+ }
+ if tools[1].Get("name").String() != "Read" {
+ t.Fatalf("tools[1].name = %q, want Read", tools[1].Get("name").String())
+ }
+}
+
+func TestNormalizeGitHubCopilotResponsesTools_FlattenToolChoiceFunctionObject(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"tool_choice":{"type":"function","function":{"name":"sum"}}}`)
+ got := normalizeGitHubCopilotResponsesTools(body)
+ if gjson.GetBytes(got, "tool_choice.type").String() != "function" {
+ t.Fatalf("tool_choice.type = %q, want function", gjson.GetBytes(got, "tool_choice.type").String())
+ }
+ if gjson.GetBytes(got, "tool_choice.name").String() != "sum" {
+ t.Fatalf("tool_choice.name = %q, want sum", gjson.GetBytes(got, "tool_choice.name").String())
+ }
+}
+
+func TestNormalizeGitHubCopilotResponsesTools_InvalidToolChoiceDowngradeToAuto(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"tool_choice":{"type":"function"}}`)
+ got := normalizeGitHubCopilotResponsesTools(body)
+ if gjson.GetBytes(got, "tool_choice").String() != "auto" {
+ t.Fatalf("tool_choice = %s, want auto", gjson.GetBytes(got, "tool_choice").Raw)
+ }
+}
+
+func TestTranslateGitHubCopilotResponsesNonStreamToClaude_TextMapping(t *testing.T) {
+ t.Parallel()
+ resp := []byte(`{"id":"resp_1","model":"gpt-5-codex","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":3,"output_tokens":5}}`)
+ out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
+ if gjson.Get(out, "type").String() != "message" {
+ t.Fatalf("type = %q, want message", gjson.Get(out, "type").String())
+ }
+ if gjson.Get(out, "content.0.type").String() != "text" {
+ t.Fatalf("content.0.type = %q, want text", gjson.Get(out, "content.0.type").String())
+ }
+ if gjson.Get(out, "content.0.text").String() != "hello" {
+ t.Fatalf("content.0.text = %q, want hello", gjson.Get(out, "content.0.text").String())
+ }
+}
+
+func TestTranslateGitHubCopilotResponsesNonStreamToClaude_ToolUseMapping(t *testing.T) {
+ t.Parallel()
+ resp := []byte(`{"id":"resp_2","model":"gpt-5-codex","output":[{"type":"function_call","id":"fc_1","call_id":"call_1","name":"sum","arguments":"{\"a\":1}"}],"usage":{"input_tokens":1,"output_tokens":2}}`)
+ out := translateGitHubCopilotResponsesNonStreamToClaude(resp)
+ if gjson.Get(out, "content.0.type").String() != "tool_use" {
+ t.Fatalf("content.0.type = %q, want tool_use", gjson.Get(out, "content.0.type").String())
+ }
+ if gjson.Get(out, "content.0.name").String() != "sum" {
+ t.Fatalf("content.0.name = %q, want sum", gjson.Get(out, "content.0.name").String())
+ }
+ if gjson.Get(out, "stop_reason").String() != "tool_use" {
+ t.Fatalf("stop_reason = %q, want tool_use", gjson.Get(out, "stop_reason").String())
+ }
+}
+
+func TestTranslateGitHubCopilotResponsesStreamToClaude_TextLifecycle(t *testing.T) {
+ t.Parallel()
+ var param any
+
+ created := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5-codex"}}`), ¶m)
+ if len(created) == 0 || !strings.Contains(created[0], "message_start") {
+ t.Fatalf("created events = %#v, want message_start", created)
+ }
+
+ delta := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.output_text.delta","delta":"he"}`), ¶m)
+ joinedDelta := strings.Join(delta, "")
+ if !strings.Contains(joinedDelta, "content_block_start") || !strings.Contains(joinedDelta, "text_delta") {
+ t.Fatalf("delta events = %#v, want content_block_start + text_delta", delta)
+ }
+
+ completed := translateGitHubCopilotResponsesStreamToClaude([]byte(`data: {"type":"response.completed","response":{"usage":{"input_tokens":7,"output_tokens":9}}}`), ¶m)
+ joinedCompleted := strings.Join(completed, "")
+ if !strings.Contains(joinedCompleted, "message_delta") || !strings.Contains(joinedCompleted, "message_stop") {
+ t.Fatalf("completed events = %#v, want message_delta + message_stop", completed)
+ }
+}
+
+// --- Tests for X-Initiator detection logic (Problem L) ---
+
+func TestApplyHeaders_XInitiator_UserOnly(t *testing.T) {
+ t.Parallel()
+ e := &GitHubCopilotExecutor{}
+ req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
+ body := []byte(`{"messages":[{"role":"system","content":"sys"},{"role":"user","content":"hello"}]}`)
+ e.applyHeaders(req, "token", body)
+ if got := req.Header.Get("X-Initiator"); got != "user" {
+ t.Fatalf("X-Initiator = %q, want user", got)
+ }
+}
+
+func TestApplyHeaders_XInitiator_AgentWithAssistantAndUserToolResult(t *testing.T) {
+ t.Parallel()
+ e := &GitHubCopilotExecutor{}
+ req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
+ // Claude Code typical flow: last message is user (tool result), but has assistant in history
+ body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"assistant","content":"I will read the file"},{"role":"user","content":"tool result here"}]}`)
+ e.applyHeaders(req, "token", body)
+ if got := req.Header.Get("X-Initiator"); got != "agent" {
+ t.Fatalf("X-Initiator = %q, want agent (assistant exists in messages)", got)
+ }
+}
+
+func TestApplyHeaders_XInitiator_AgentWithToolRole(t *testing.T) {
+ t.Parallel()
+ e := &GitHubCopilotExecutor{}
+ req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
+ body := []byte(`{"messages":[{"role":"user","content":"hello"},{"role":"tool","content":"result"}]}`)
+ e.applyHeaders(req, "token", body)
+ if got := req.Header.Get("X-Initiator"); got != "agent" {
+ t.Fatalf("X-Initiator = %q, want agent (tool role exists)", got)
+ }
+}
+
+// --- Tests for x-github-api-version header (Problem M) ---
+
+func TestApplyHeaders_GitHubAPIVersion(t *testing.T) {
+ t.Parallel()
+ e := &GitHubCopilotExecutor{}
+ req, _ := http.NewRequest(http.MethodPost, "https://example.com", nil)
+ e.applyHeaders(req, "token", nil)
+ if got := req.Header.Get("X-Github-Api-Version"); got != "2025-04-01" {
+ t.Fatalf("X-Github-Api-Version = %q, want 2025-04-01", got)
+ }
+}
+
+// --- Tests for vision detection (Problem P) ---
+
+func TestDetectVisionContent_WithImageURL(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"describe"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc"}}]}]}`)
+ if !detectVisionContent(body) {
+ t.Fatal("expected vision content to be detected")
+ }
+}
+
+func TestDetectVisionContent_WithImageType(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"messages":[{"role":"user","content":[{"type":"image","source":{"data":"abc","media_type":"image/png"}}]}]}`)
+ if !detectVisionContent(body) {
+ t.Fatal("expected image type to be detected")
+ }
+}
+
+func TestDetectVisionContent_NoVision(t *testing.T) {
+ t.Parallel()
+ body := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
+ if detectVisionContent(body) {
+ t.Fatal("expected no vision content")
+ }
+}
+
+func TestDetectVisionContent_NoMessages(t *testing.T) {
+ t.Parallel()
+ // After Responses API normalization, messages is removed — detection should return false
+ body := []byte(`{"input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`)
+ if detectVisionContent(body) {
+ t.Fatal("expected no vision content when messages field is absent")
+ }
+}
diff --git a/internal/runtime/executor/kilo_executor.go b/internal/runtime/executor/kilo_executor.go
new file mode 100644
index 00000000..b2359319
--- /dev/null
+++ b/internal/runtime/executor/kilo_executor.go
@@ -0,0 +1,459 @@
+package executor
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+ cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+ cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
+ sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
+ log "github.com/sirupsen/logrus"
+ "github.com/tidwall/gjson"
+)
+
+// KiloExecutor handles requests to Kilo API.
+type KiloExecutor struct {
+ cfg *config.Config
+}
+
+// NewKiloExecutor creates a new Kilo executor instance.
+func NewKiloExecutor(cfg *config.Config) *KiloExecutor {
+ return &KiloExecutor{cfg: cfg}
+}
+
+// Identifier returns the unique identifier for this executor.
+func (e *KiloExecutor) Identifier() string { return "kilo" }
+
+// PrepareRequest prepares the HTTP request before execution.
+func (e *KiloExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
+ if req == nil {
+ return nil
+ }
+ accessToken, _ := kiloCredentials(auth)
+ if strings.TrimSpace(accessToken) == "" {
+ return fmt.Errorf("kilo: missing access token")
+ }
+
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ var attrs map[string]string
+ if auth != nil {
+ attrs = auth.Attributes
+ }
+ util.ApplyCustomHeadersFromAttrs(req, attrs)
+ return nil
+}
+
+// HttpRequest executes a raw HTTP request.
+func (e *KiloExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
+ if req == nil {
+ return nil, fmt.Errorf("kilo executor: request is nil")
+ }
+ if ctx == nil {
+ ctx = req.Context()
+ }
+ httpReq := req.WithContext(ctx)
+ if err := e.PrepareRequest(httpReq, auth); err != nil {
+ return nil, err
+ }
+ httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
+ return httpClient.Do(httpReq)
+}
+
+// Execute performs a non-streaming request.
+func (e *KiloExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
+ baseModel := thinking.ParseSuffix(req.Model).ModelName
+
+ reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
+ defer reporter.trackFailure(ctx, &err)
+
+ accessToken, orgID := kiloCredentials(auth)
+ if accessToken == "" {
+ return resp, fmt.Errorf("kilo: missing access token")
+ }
+
+ from := opts.SourceFormat
+ to := sdktranslator.FromString("openai")
+ endpoint := "/api/openrouter/chat/completions"
+
+ originalPayloadSource := req.Payload
+ if len(opts.OriginalRequest) > 0 {
+ originalPayloadSource = opts.OriginalRequest
+ }
+ originalPayload := originalPayloadSource
+ originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
+ translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
+ requestedModel := payloadRequestedModel(opts, req.Model)
+ translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
+
+ translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
+ if err != nil {
+ return resp, err
+ }
+
+ url := "https://api.kilo.ai" + endpoint
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
+ if err != nil {
+ return resp, err
+ }
+ httpReq.Header.Set("Content-Type", "application/json")
+ httpReq.Header.Set("Authorization", "Bearer "+accessToken)
+ if orgID != "" {
+ httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
+ }
+ httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
+ var attrs map[string]string
+ if auth != nil {
+ attrs = auth.Attributes
+ }
+ util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
+
+ var authID, authLabel, authType, authValue string
+ if auth != nil {
+ authID = auth.ID
+ authLabel = auth.Label
+ authType, authValue = auth.AccountInfo()
+ }
+ recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
+ URL: url,
+ Method: http.MethodPost,
+ Headers: httpReq.Header.Clone(),
+ Body: translated,
+ Provider: e.Identifier(),
+ AuthID: authID,
+ AuthLabel: authLabel,
+ AuthType: authType,
+ AuthValue: authValue,
+ })
+
+ httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
+ httpResp, err := httpClient.Do(httpReq)
+ if err != nil {
+ recordAPIResponseError(ctx, e.cfg, err)
+ return resp, err
+ }
+ defer httpResp.Body.Close()
+
+ recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
+ if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
+ b, _ := io.ReadAll(httpResp.Body)
+ appendAPIResponseChunk(ctx, e.cfg, b)
+ err = statusErr{code: httpResp.StatusCode, msg: string(b)}
+ return resp, err
+ }
+
+ body, err := io.ReadAll(httpResp.Body)
+ if err != nil {
+ recordAPIResponseError(ctx, e.cfg, err)
+ return resp, err
+ }
+ appendAPIResponseChunk(ctx, e.cfg, body)
+ reporter.publish(ctx, parseOpenAIUsage(body))
+ reporter.ensurePublished(ctx)
+
+ var param any
+ out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
+ resp = cliproxyexecutor.Response{Payload: []byte(out)}
+ return resp, nil
+}
+
+// ExecuteStream performs a streaming request.
+func (e *KiloExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
+ baseModel := thinking.ParseSuffix(req.Model).ModelName
+
+ reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
+ defer reporter.trackFailure(ctx, &err)
+
+ accessToken, orgID := kiloCredentials(auth)
+ if accessToken == "" {
+ return nil, fmt.Errorf("kilo: missing access token")
+ }
+
+ from := opts.SourceFormat
+ to := sdktranslator.FromString("openai")
+ endpoint := "/api/openrouter/chat/completions"
+
+ originalPayloadSource := req.Payload
+ if len(opts.OriginalRequest) > 0 {
+ originalPayloadSource = opts.OriginalRequest
+ }
+ originalPayload := originalPayloadSource
+ originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
+ translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
+ requestedModel := payloadRequestedModel(opts, req.Model)
+ translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
+
+ translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
+ if err != nil {
+ return nil, err
+ }
+
+ url := "https://api.kilo.ai" + endpoint
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
+ if err != nil {
+ return nil, err
+ }
+ httpReq.Header.Set("Content-Type", "application/json")
+ httpReq.Header.Set("Authorization", "Bearer "+accessToken)
+ if orgID != "" {
+ httpReq.Header.Set("X-Kilocode-OrganizationID", orgID)
+ }
+ httpReq.Header.Set("User-Agent", "cli-proxy-kilo")
+ httpReq.Header.Set("Accept", "text/event-stream")
+ httpReq.Header.Set("Cache-Control", "no-cache")
+
+ var attrs map[string]string
+ if auth != nil {
+ attrs = auth.Attributes
+ }
+ util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
+
+ var authID, authLabel, authType, authValue string
+ if auth != nil {
+ authID = auth.ID
+ authLabel = auth.Label
+ authType, authValue = auth.AccountInfo()
+ }
+ recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
+ URL: url,
+ Method: http.MethodPost,
+ Headers: httpReq.Header.Clone(),
+ Body: translated,
+ Provider: e.Identifier(),
+ AuthID: authID,
+ AuthLabel: authLabel,
+ AuthType: authType,
+ AuthValue: authValue,
+ })
+
+ httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
+ httpResp, err := httpClient.Do(httpReq)
+ if err != nil {
+ recordAPIResponseError(ctx, e.cfg, err)
+ return nil, err
+ }
+
+ recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
+ if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
+ b, _ := io.ReadAll(httpResp.Body)
+ appendAPIResponseChunk(ctx, e.cfg, b)
+ httpResp.Body.Close()
+ err = statusErr{code: httpResp.StatusCode, msg: string(b)}
+ return nil, err
+ }
+
+ out := make(chan cliproxyexecutor.StreamChunk)
+ stream = out
+ go func() {
+ defer close(out)
+ defer httpResp.Body.Close()
+
+ scanner := bufio.NewScanner(httpResp.Body)
+ scanner.Buffer(nil, 52_428_800)
+ var param any
+ for scanner.Scan() {
+ line := scanner.Bytes()
+ appendAPIResponseChunk(ctx, e.cfg, line)
+ if detail, ok := parseOpenAIStreamUsage(line); ok {
+ reporter.publish(ctx, detail)
+ }
+ if len(line) == 0 {
+ continue
+ }
+ if !bytes.HasPrefix(line, []byte("data:")) {
+ continue
+ }
+ chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
+ for i := range chunks {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
+ }
+ }
+ if errScan := scanner.Err(); errScan != nil {
+ recordAPIResponseError(ctx, e.cfg, errScan)
+ reporter.publishFailure(ctx)
+ out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ }
+ reporter.ensurePublished(ctx)
+ }()
+
+ return stream, nil
+}
+
+// Refresh validates the Kilo token.
+func (e *KiloExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
+ if auth == nil {
+ return nil, fmt.Errorf("missing auth")
+ }
+ return auth, nil
+}
+
+// CountTokens returns the token count for the given request.
+func (e *KiloExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
+ return cliproxyexecutor.Response{}, fmt.Errorf("kilo: count tokens not supported")
+}
+
+// kiloCredentials extracts access token and other info from auth.
+func kiloCredentials(auth *cliproxyauth.Auth) (accessToken, orgID string) {
+ if auth == nil {
+ return "", ""
+ }
+
+ // Prefer kilocode specific keys, then fall back to generic keys.
+ // Check metadata first, then attributes.
+ if auth.Metadata != nil {
+ if token, ok := auth.Metadata["kilocodeToken"].(string); ok && token != "" {
+ accessToken = token
+ } else if token, ok := auth.Metadata["access_token"].(string); ok && token != "" {
+ accessToken = token
+ }
+
+ if org, ok := auth.Metadata["kilocodeOrganizationId"].(string); ok && org != "" {
+ orgID = org
+ } else if org, ok := auth.Metadata["organization_id"].(string); ok && org != "" {
+ orgID = org
+ }
+ }
+
+ if accessToken == "" && auth.Attributes != nil {
+ if token := auth.Attributes["kilocodeToken"]; token != "" {
+ accessToken = token
+ } else if token := auth.Attributes["access_token"]; token != "" {
+ accessToken = token
+ }
+ }
+
+ if orgID == "" && auth.Attributes != nil {
+ if org := auth.Attributes["kilocodeOrganizationId"]; org != "" {
+ orgID = org
+ } else if org := auth.Attributes["organization_id"]; org != "" {
+ orgID = org
+ }
+ }
+
+ return accessToken, orgID
+}
+
+// FetchKiloModels fetches models from Kilo API.
+func FetchKiloModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *config.Config) []*registry.ModelInfo {
+ accessToken, orgID := kiloCredentials(auth)
+ if accessToken == "" {
+ log.Infof("kilo: no access token found, skipping dynamic model fetch (using static kilo/auto)")
+ return registry.GetKiloModels()
+ }
+
+ log.Debugf("kilo: fetching dynamic models (orgID: %s)", orgID)
+
+ httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.kilo.ai/api/openrouter/models", nil)
+ if err != nil {
+ log.Warnf("kilo: failed to create model fetch request: %v", err)
+ return registry.GetKiloModels()
+ }
+
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ if orgID != "" {
+ req.Header.Set("X-Kilocode-OrganizationID", orgID)
+ }
+ req.Header.Set("User-Agent", "cli-proxy-kilo")
+
+ resp, err := httpClient.Do(req)
+ if err != nil {
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ log.Warnf("kilo: fetch models canceled: %v", err)
+ } else {
+ log.Warnf("kilo: using static models (API fetch failed: %v)", err)
+ }
+ return registry.GetKiloModels()
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ log.Warnf("kilo: failed to read models response: %v", err)
+ return registry.GetKiloModels()
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ log.Warnf("kilo: fetch models failed: status %d, body: %s", resp.StatusCode, string(body))
+ return registry.GetKiloModels()
+ }
+
+ result := gjson.GetBytes(body, "data")
+ if !result.Exists() {
+ // Try root if data field is missing
+ result = gjson.ParseBytes(body)
+ if !result.IsArray() {
+ log.Debugf("kilo: response body: %s", string(body))
+ log.Warn("kilo: invalid API response format (expected array or data field with array)")
+ return registry.GetKiloModels()
+ }
+ }
+
+ var dynamicModels []*registry.ModelInfo
+ now := time.Now().Unix()
+ count := 0
+ totalCount := 0
+
+ result.ForEach(func(key, value gjson.Result) bool {
+ totalCount++
+ id := value.Get("id").String()
+ pIdxResult := value.Get("preferredIndex")
+ preferredIndex := pIdxResult.Int()
+
+ // Filter models where preferredIndex > 0 (Kilo-curated models)
+ if preferredIndex <= 0 {
+ return true
+ }
+
+ // Check if it's free. We look for :free suffix, is_free flag, or zero pricing.
+ isFree := strings.HasSuffix(id, ":free") || id == "giga-potato" || value.Get("is_free").Bool()
+ if !isFree {
+ // Check pricing as fallback
+ promptPricing := value.Get("pricing.prompt").String()
+ if promptPricing == "0" || promptPricing == "0.0" {
+ isFree = true
+ }
+ }
+
+ if !isFree {
+ log.Debugf("kilo: skipping curated paid model: %s", id)
+ return true
+ }
+
+ log.Debugf("kilo: found curated model: %s (preferredIndex: %d)", id, preferredIndex)
+
+ dynamicModels = append(dynamicModels, ®istry.ModelInfo{
+ ID: id,
+ DisplayName: value.Get("name").String(),
+ ContextLength: int(value.Get("context_length").Int()),
+ OwnedBy: "kilo",
+ Type: "kilo",
+ Object: "model",
+ Created: now,
+ })
+ count++
+ return true
+ })
+
+ log.Infof("kilo: fetched %d models from API, %d curated free (preferredIndex > 0)", totalCount, count)
+ if count == 0 && totalCount > 0 {
+ log.Warn("kilo: no curated free models found (check API response fields)")
+ }
+
+ staticModels := registry.GetKiloModels()
+ // Always include kilo/auto (first static model)
+ allModels := append(staticModels[:1], dynamicModels...)
+
+ return allModels
+}
+
diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go
new file mode 100644
index 00000000..e1a280b9
--- /dev/null
+++ b/internal/runtime/executor/kiro_executor.go
@@ -0,0 +1,4795 @@
+package executor
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "github.com/google/uuid"
+ kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ kiroclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude"
+ kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
+ kiroopenai "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+ cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+ cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
+ sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+ // Kiro API common constants
+ kiroContentType = "application/json"
+ kiroAcceptStream = "*/*"
+
+ // Event Stream frame size constants for boundary protection
+ // AWS Event Stream binary format: prelude (12 bytes) + headers + payload + message_crc (4 bytes)
+ // Prelude consists of: total_length (4) + headers_length (4) + prelude_crc (4)
+ minEventStreamFrameSize = 16 // Minimum: 4(total_len) + 4(headers_len) + 4(prelude_crc) + 4(message_crc)
+ maxEventStreamMsgSize = 10 << 20 // Maximum message length: 10MB
+
+ // Event Stream error type constants
+ ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable
+ ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed
+
+ // kiroUserAgent matches Amazon Q CLI style for User-Agent header
+ kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"
+ // kiroFullUserAgent is the complete x-amz-user-agent header (Amazon Q CLI style)
+ kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"
+
+ // Kiro IDE style headers for IDC auth
+ kiroIDEUserAgent = "aws-sdk-js/1.0.27 ua/2.1 os/win32#10.0.19044 lang/js md/nodejs#22.21.1 api/codewhispererstreaming#1.0.27 m/E"
+ kiroIDEAmzUserAgent = "aws-sdk-js/1.0.27"
+ kiroIDEAgentModeVibe = "vibe"
+
+ // Socket retry configuration constants
+ // Maximum number of retry attempts for socket/network errors
+ kiroSocketMaxRetries = 3
+ // Base delay between retry attempts (uses exponential backoff: delay * 2^attempt)
+ kiroSocketBaseRetryDelay = 1 * time.Second
+ // Maximum delay between retry attempts (cap for exponential backoff)
+ kiroSocketMaxRetryDelay = 30 * time.Second
+ // First token timeout for streaming responses (how long to wait for first response)
+ kiroFirstTokenTimeout = 15 * time.Second
+ // Streaming read timeout (how long to wait between chunks)
+ kiroStreamingReadTimeout = 300 * time.Second
+)
+
+// retryableHTTPStatusCodes defines HTTP status codes that are considered retryable.
+// Based on kiro2Api reference: 502 (Bad Gateway), 503 (Service Unavailable), 504 (Gateway Timeout)
+var retryableHTTPStatusCodes = map[int]bool{
+ 502: true, // Bad Gateway - upstream server error
+ 503: true, // Service Unavailable - server temporarily overloaded
+ 504: true, // Gateway Timeout - upstream server timeout
+}
+
+// Real-time usage estimation configuration
+// These control how often usage updates are sent during streaming
+var (
+ usageUpdateCharThreshold = 5000 // Send usage update every 5000 characters
+ usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first
+)
+
+// Global FingerprintManager for dynamic User-Agent generation per token
+// Each token gets a unique fingerprint on first use, which is cached for subsequent requests
+var (
+ globalFingerprintManager *kiroauth.FingerprintManager
+ globalFingerprintManagerOnce sync.Once
+)
+
+// getGlobalFingerprintManager returns the global FingerprintManager instance
+func getGlobalFingerprintManager() *kiroauth.FingerprintManager {
+ globalFingerprintManagerOnce.Do(func() {
+ globalFingerprintManager = kiroauth.NewFingerprintManager()
+ log.Infof("kiro: initialized global FingerprintManager for dynamic UA generation")
+ })
+ return globalFingerprintManager
+}
+
+// retryConfig holds configuration for socket retry logic.
+// Based on kiro2Api Python implementation patterns.
+type retryConfig struct {
+ MaxRetries int // Maximum number of retry attempts
+ BaseDelay time.Duration // Base delay between retries (exponential backoff)
+ MaxDelay time.Duration // Maximum delay cap
+ RetryableErrors []string // List of retryable error patterns
+ RetryableStatus map[int]bool // HTTP status codes to retry
+ FirstTokenTmout time.Duration // Timeout for first token in streaming
+ StreamReadTmout time.Duration // Timeout between stream chunks
+}
+
+// defaultRetryConfig returns the default retry configuration for Kiro socket operations.
+func defaultRetryConfig() retryConfig {
+ return retryConfig{
+ MaxRetries: kiroSocketMaxRetries,
+ BaseDelay: kiroSocketBaseRetryDelay,
+ MaxDelay: kiroSocketMaxRetryDelay,
+ RetryableStatus: retryableHTTPStatusCodes,
+ RetryableErrors: []string{
+ "connection reset",
+ "connection refused",
+ "broken pipe",
+ "EOF",
+ "timeout",
+ "temporary failure",
+ "no such host",
+ "network is unreachable",
+ "i/o timeout",
+ },
+ FirstTokenTmout: kiroFirstTokenTimeout,
+ StreamReadTmout: kiroStreamingReadTimeout,
+ }
+}
+
+// isRetryableError checks if an error is retryable based on error type and message.
+// Returns true for network timeouts, connection resets, and temporary failures.
+// Based on kiro2Api's retry logic patterns.
+func isRetryableError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ // Check for context cancellation - not retryable
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ return false
+ }
+
+ // Check for net.Error (timeout, temporary)
+ var netErr net.Error
+ if errors.As(err, &netErr) {
+ if netErr.Timeout() {
+ log.Debugf("kiro: isRetryableError: network timeout detected")
+ return true
+ }
+ // Note: Temporary() is deprecated but still useful for some error types
+ }
+
+ // Check for specific syscall errors (connection reset, broken pipe, etc.)
+ var syscallErr syscall.Errno
+ if errors.As(err, &syscallErr) {
+ switch syscallErr {
+ case syscall.ECONNRESET: // Connection reset by peer
+ log.Debugf("kiro: isRetryableError: ECONNRESET detected")
+ return true
+ case syscall.ECONNREFUSED: // Connection refused
+ log.Debugf("kiro: isRetryableError: ECONNREFUSED detected")
+ return true
+ case syscall.EPIPE: // Broken pipe
+ log.Debugf("kiro: isRetryableError: EPIPE (broken pipe) detected")
+ return true
+ case syscall.ETIMEDOUT: // Connection timed out
+ log.Debugf("kiro: isRetryableError: ETIMEDOUT detected")
+ return true
+ case syscall.ENETUNREACH: // Network is unreachable
+ log.Debugf("kiro: isRetryableError: ENETUNREACH detected")
+ return true
+ case syscall.EHOSTUNREACH: // No route to host
+ log.Debugf("kiro: isRetryableError: EHOSTUNREACH detected")
+ return true
+ }
+ }
+
+ // Check for net.OpError wrapping other errors
+ var opErr *net.OpError
+ if errors.As(err, &opErr) {
+ log.Debugf("kiro: isRetryableError: net.OpError detected, op=%s", opErr.Op)
+ // Recursively check the wrapped error
+ if opErr.Err != nil {
+ return isRetryableError(opErr.Err)
+ }
+ return true
+ }
+
+ // Check error message for retryable patterns
+ errMsg := strings.ToLower(err.Error())
+ cfg := defaultRetryConfig()
+ for _, pattern := range cfg.RetryableErrors {
+ if strings.Contains(errMsg, pattern) {
+ log.Debugf("kiro: isRetryableError: pattern '%s' matched in error: %s", pattern, errMsg)
+ return true
+ }
+ }
+
+ // Check for EOF which may indicate connection was closed
+ if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+ log.Debugf("kiro: isRetryableError: EOF/UnexpectedEOF detected")
+ return true
+ }
+
+ return false
+}
+
+// isRetryableHTTPStatus checks if an HTTP status code is retryable.
+// Based on kiro2Api: 502, 503, 504 are retryable server errors.
+func isRetryableHTTPStatus(statusCode int) bool {
+ return retryableHTTPStatusCodes[statusCode]
+}
+
+// calculateRetryDelay calculates the delay for the next retry attempt using exponential backoff.
+// delay = min(baseDelay * 2^attempt, maxDelay)
+// Adds ±30% jitter to prevent thundering herd.
+func calculateRetryDelay(attempt int, cfg retryConfig) time.Duration {
+ return kiroauth.ExponentialBackoffWithJitter(attempt, cfg.BaseDelay, cfg.MaxDelay)
+}
+
+// logRetryAttempt logs a retry attempt with relevant context.
+func logRetryAttempt(attempt, maxRetries int, reason string, delay time.Duration, endpoint string) {
+ log.Warnf("kiro: retry attempt %d/%d for %s, waiting %v before next attempt (endpoint: %s)",
+ attempt+1, maxRetries, reason, delay, endpoint)
+}
+
+// kiroHTTPClientPool provides a shared HTTP client with connection pooling for Kiro API.
+// This reduces connection overhead and improves performance for concurrent requests.
+// Based on kiro2Api's connection pooling pattern.
+var (
+ kiroHTTPClientPool *http.Client
+ kiroHTTPClientPoolOnce sync.Once
+)
+
+// getKiroPooledHTTPClient returns a shared HTTP client with optimized connection pooling.
+// The client is lazily initialized on first use and reused across requests.
+// This is especially beneficial for:
+// - Reducing TCP handshake overhead
+// - Enabling HTTP/2 multiplexing
+// - Better handling of keep-alive connections
+func getKiroPooledHTTPClient() *http.Client {
+ kiroHTTPClientPoolOnce.Do(func() {
+ transport := &http.Transport{
+ // Connection pool settings
+ MaxIdleConns: 100, // Max idle connections across all hosts
+ MaxIdleConnsPerHost: 20, // Max idle connections per host
+ MaxConnsPerHost: 50, // Max total connections per host
+ IdleConnTimeout: 90 * time.Second, // How long idle connections stay in pool
+
+ // Timeouts for connection establishment
+ DialContext: (&net.Dialer{
+ Timeout: 30 * time.Second, // TCP connection timeout
+ KeepAlive: 30 * time.Second, // TCP keep-alive interval
+ }).DialContext,
+
+ // TLS handshake timeout
+ TLSHandshakeTimeout: 10 * time.Second,
+
+ // Response header timeout
+ ResponseHeaderTimeout: 30 * time.Second,
+
+ // Expect 100-continue timeout
+ ExpectContinueTimeout: 1 * time.Second,
+
+ // Enable HTTP/2 when available
+ ForceAttemptHTTP2: true,
+ }
+
+ kiroHTTPClientPool = &http.Client{
+ Transport: transport,
+ // No global timeout - let individual requests set their own timeouts via context
+ }
+
+ log.Debugf("kiro: initialized pooled HTTP client (MaxIdleConns=%d, MaxIdleConnsPerHost=%d, MaxConnsPerHost=%d)",
+ transport.MaxIdleConns, transport.MaxIdleConnsPerHost, transport.MaxConnsPerHost)
+ })
+
+ return kiroHTTPClientPool
+}
+
+// newKiroHTTPClientWithPooling creates an HTTP client that uses connection pooling when appropriate.
+// It respects proxy configuration from auth or config, falling back to the pooled client.
+// This provides the best of both worlds: custom proxy support + connection reuse.
+func newKiroHTTPClientWithPooling(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
+ // Check if a proxy is configured - if so, we need a custom client
+ var proxyURL string
+ if auth != nil {
+ proxyURL = strings.TrimSpace(auth.ProxyURL)
+ }
+ if proxyURL == "" && cfg != nil {
+ proxyURL = strings.TrimSpace(cfg.ProxyURL)
+ }
+
+ // If proxy is configured, use the existing proxy-aware client (doesn't pool)
+ if proxyURL != "" {
+ log.Debugf("kiro: using proxy-aware HTTP client (proxy=%s)", proxyURL)
+ return newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
+ }
+
+ // No proxy - use pooled client for better performance
+ pooledClient := getKiroPooledHTTPClient()
+
+ // If timeout is specified, we need to wrap the pooled transport with timeout
+ if timeout > 0 {
+ return &http.Client{
+ Transport: pooledClient.Transport,
+ Timeout: timeout,
+ }
+ }
+
+ return pooledClient
+}
+
+// kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values.
+// This solves the "triple mismatch" problem where different endpoints require matching
+// Origin and X-Amz-Target header values.
+//
+// Based on reference implementations:
+// - amq2api-main: Uses Amazon Q endpoint with CLI origin and AmazonQDeveloperStreamingService target
+// - AIClient-2-API: Uses CodeWhisperer endpoint with AI_EDITOR origin and AmazonCodeWhispererStreamingService target
+type kiroEndpointConfig struct {
+ URL string // Endpoint URL
+ Origin string // Request Origin: "CLI" for Amazon Q quota, "AI_EDITOR" for Kiro IDE quota
+ AmzTarget string // X-Amz-Target header value
+ Name string // Endpoint name for logging
+}
+
+// kiroDefaultRegion is the default AWS region for Kiro API endpoints.
+// Used when no region is specified in auth metadata.
+const kiroDefaultRegion = "us-east-1"
+
+// extractRegionFromProfileARN extracts the AWS region from a ProfileARN.
+// ARN format: arn:aws:codewhisperer:REGION:ACCOUNT:profile/PROFILE_ID
+// Returns empty string if region cannot be extracted.
+func extractRegionFromProfileARN(profileArn string) string {
+ if profileArn == "" {
+ return ""
+ }
+ parts := strings.Split(profileArn, ":")
+ if len(parts) >= 4 && parts[3] != "" {
+ return parts[3]
+ }
+ return ""
+}
+
+// buildKiroEndpointConfigs creates endpoint configurations for the specified region.
+// This enables dynamic region support for Enterprise/IdC users in non-us-east-1 regions.
+//
+// Uses Q endpoint (q.{region}.amazonaws.com) as primary for ALL auth types:
+// - Works universally across all AWS regions (CodeWhisperer endpoint only exists in us-east-1)
+// - Uses /generateAssistantResponse path with AI_EDITOR origin
+// - Does NOT require X-Amz-Target header
+//
+// The AmzTarget field is kept for backward compatibility but should be empty
+// to indicate that the header should NOT be set.
+func buildKiroEndpointConfigs(region string) []kiroEndpointConfig {
+ if region == "" {
+ region = kiroDefaultRegion
+ }
+ return []kiroEndpointConfig{
+ {
+ // Primary: Q endpoint - works for all regions and auth types
+ URL: fmt.Sprintf("https://q.%s.amazonaws.com/generateAssistantResponse", region),
+ Origin: "AI_EDITOR",
+ AmzTarget: "", // Empty = don't set X-Amz-Target header
+ Name: "AmazonQ",
+ },
+ {
+ // Fallback: CodeWhisperer endpoint (legacy, only works in us-east-1)
+ URL: fmt.Sprintf("https://codewhisperer.%s.amazonaws.com/generateAssistantResponse", region),
+ Origin: "AI_EDITOR",
+ AmzTarget: "AmazonCodeWhispererStreamingService.GenerateAssistantResponse",
+ Name: "CodeWhisperer",
+ },
+ }
+}
+
+// resolveKiroAPIRegion determines the AWS region for Kiro API calls.
+// Region priority:
+// 1. auth.Metadata["api_region"] - explicit API region override
+// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource
+// 3. kiroDefaultRegion (us-east-1) - fallback
+// Note: OIDC "region" is NOT used - it's for token refresh, not API calls
+func resolveKiroAPIRegion(auth *cliproxyauth.Auth) string {
+ if auth == nil || auth.Metadata == nil {
+ return kiroDefaultRegion
+ }
+ // Priority 1: Explicit api_region override
+ if r, ok := auth.Metadata["api_region"].(string); ok && r != "" {
+ log.Debugf("kiro: using region %s (source: api_region)", r)
+ return r
+ }
+ // Priority 2: Extract from ProfileARN
+ if profileArn, ok := auth.Metadata["profile_arn"].(string); ok && profileArn != "" {
+ if arnRegion := extractRegionFromProfileARN(profileArn); arnRegion != "" {
+ log.Debugf("kiro: using region %s (source: profile_arn)", arnRegion)
+ return arnRegion
+ }
+ }
+ // Note: OIDC "region" field is NOT used for API endpoint
+ // Kiro API only exists in us-east-1, while OIDC region can vary (e.g., ap-northeast-2)
+ // Using OIDC region for API calls causes DNS failures
+ log.Debugf("kiro: using region %s (source: default)", kiroDefaultRegion)
+ return kiroDefaultRegion
+}
+
+// kiroEndpointConfigs is kept for backward compatibility with default us-east-1 region.
+// Prefer using buildKiroEndpointConfigs(region) for dynamic region support.
+var kiroEndpointConfigs = buildKiroEndpointConfigs(kiroDefaultRegion)
+
+// getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order.
+// Supports dynamic region based on auth metadata "api_region", "profile_arn", or "region" field.
+// Supports reordering based on "preferred_endpoint" in auth metadata/attributes.
+//
+// Region priority:
+// 1. auth.Metadata["api_region"] - explicit API region override
+// 2. ProfileARN region - extracted from arn:aws:service:REGION:account:resource
+// 3. kiroDefaultRegion (us-east-1) - fallback
+// Note: OIDC "region" is NOT used - it's for token refresh, not API calls
+func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig {
+ if auth == nil {
+ return kiroEndpointConfigs
+ }
+
+ // Determine API region using shared resolution logic
+ region := resolveKiroAPIRegion(auth)
+
+ // Build endpoint configs for the specified region
+ endpointConfigs := buildKiroEndpointConfigs(region)
+
+ // For IDC auth, use Q endpoint with AI_EDITOR origin
+ // IDC tokens work with Q endpoint using Bearer auth
+ // The difference is only in how tokens are refreshed (OIDC with clientId/clientSecret for IDC)
+ // NOT in how API calls are made - both Social and IDC use the same endpoint/origin
+ if auth.Metadata != nil {
+ authMethod, _ := auth.Metadata["auth_method"].(string)
+ if strings.ToLower(authMethod) == "idc" {
+ log.Debugf("kiro: IDC auth, using Q endpoint (region: %s)", region)
+ return endpointConfigs
+ }
+ }
+
+ // Check for preference
+ var preference string
+ if auth.Metadata != nil {
+ if p, ok := auth.Metadata["preferred_endpoint"].(string); ok {
+ preference = p
+ }
+ }
+ // Check attributes as fallback (e.g. from HTTP headers)
+ if preference == "" && auth.Attributes != nil {
+ preference = auth.Attributes["preferred_endpoint"]
+ }
+
+ if preference == "" {
+ return endpointConfigs
+ }
+
+ preference = strings.ToLower(strings.TrimSpace(preference))
+
+ // Create new slice to avoid modifying global state
+ var sorted []kiroEndpointConfig
+ var remaining []kiroEndpointConfig
+
+ for _, cfg := range endpointConfigs {
+ name := strings.ToLower(cfg.Name)
+ // Check for matches
+ // CodeWhisperer aliases: codewhisperer, ide
+ // AmazonQ aliases: amazonq, q, cli
+ isMatch := false
+ if (preference == "codewhisperer" || preference == "ide") && name == "codewhisperer" {
+ isMatch = true
+ } else if (preference == "amazonq" || preference == "q" || preference == "cli") && name == "amazonq" {
+ isMatch = true
+ }
+
+ if isMatch {
+ sorted = append(sorted, cfg)
+ } else {
+ remaining = append(remaining, cfg)
+ }
+ }
+
+ // If preference didn't match anything, return default
+ if len(sorted) == 0 {
+ return endpointConfigs
+ }
+
+ // Combine: preferred first, then others
+ return append(sorted, remaining...)
+}
+
+// KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API.
+type KiroExecutor struct {
+ cfg *config.Config
+ refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions
+}
+
+// isIDCAuth checks if the auth uses IDC (Identity Center) authentication method.
+func isIDCAuth(auth *cliproxyauth.Auth) bool {
+ if auth == nil || auth.Metadata == nil {
+ return false
+ }
+ authMethod, _ := auth.Metadata["auth_method"].(string)
+ return strings.ToLower(authMethod) == "idc"
+}
+
+// buildKiroPayloadForFormat builds the Kiro API payload based on the source format.
+// This is critical because OpenAI and Claude formats have different tool structures:
+// - OpenAI: tools[].function.name, tools[].function.description
+// - Claude: tools[].name, tools[].description
+// headers parameter allows checking Anthropic-Beta header for thinking mode detection.
+// Returns the serialized JSON payload and a boolean indicating whether thinking mode was injected.
+func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, sourceFormat sdktranslator.Format, headers http.Header) ([]byte, bool) {
+ switch sourceFormat.String() {
+ case "openai":
+ log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String())
+ return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
+ case "kiro":
+ // Body is already in Kiro format — pass through directly
+ log.Debugf("kiro: body already in Kiro format, passing through directly")
+ return body, false
+ default:
+ // Default to Claude format
+ log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String())
+ return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly, headers, nil)
+ }
+}
+
+// NewKiroExecutor creates a new Kiro executor instance.
+func NewKiroExecutor(cfg *config.Config) *KiroExecutor {
+ return &KiroExecutor{cfg: cfg}
+}
+
+// Identifier returns the unique identifier for this executor.
+func (e *KiroExecutor) Identifier() string { return "kiro" }
+
+// applyDynamicFingerprint applies token-specific fingerprint headers to the request
+// For IDC auth, uses dynamic fingerprint-based User-Agent
+// For other auth types, uses static Amazon Q CLI style headers
+func applyDynamicFingerprint(req *http.Request, auth *cliproxyauth.Auth) {
+ if isIDCAuth(auth) {
+ // Get token-specific fingerprint for dynamic UA generation
+ tokenKey := getTokenKey(auth)
+ fp := getGlobalFingerprintManager().GetFingerprint(tokenKey)
+
+ // Use fingerprint-generated dynamic User-Agent
+ req.Header.Set("User-Agent", fp.BuildUserAgent())
+ req.Header.Set("X-Amz-User-Agent", fp.BuildAmzUserAgent())
+ req.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
+
+ log.Debugf("kiro: using dynamic fingerprint for token %s (SDK:%s, OS:%s/%s, Kiro:%s)",
+ tokenKey[:8]+"...", fp.SDKVersion, fp.OSType, fp.OSVersion, fp.KiroVersion)
+ } else {
+ // Use static Amazon Q CLI style headers for non-IDC auth
+ req.Header.Set("User-Agent", kiroUserAgent)
+ req.Header.Set("X-Amz-User-Agent", kiroFullUserAgent)
+ }
+}
+
+// PrepareRequest prepares the HTTP request before execution.
+func (e *KiroExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
+ if req == nil {
+ return nil
+ }
+ accessToken, _ := kiroCredentials(auth)
+ if strings.TrimSpace(accessToken) == "" {
+ return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
+ }
+
+ // Apply dynamic fingerprint-based headers
+ applyDynamicFingerprint(req, auth)
+
+ req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
+ req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
+ req.Header.Set("Authorization", "Bearer "+accessToken)
+ var attrs map[string]string
+ if auth != nil {
+ attrs = auth.Attributes
+ }
+ util.ApplyCustomHeadersFromAttrs(req, attrs)
+ return nil
+}
+
+// HttpRequest injects Kiro credentials into the request and executes it.
+func (e *KiroExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
+ if req == nil {
+ return nil, fmt.Errorf("kiro executor: request is nil")
+ }
+ if ctx == nil {
+ ctx = req.Context()
+ }
+ httpReq := req.WithContext(ctx)
+ if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil {
+ return nil, errPrepare
+ }
+ httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 0)
+ return httpClient.Do(httpReq)
+}
+
+// getTokenKey returns a unique key for rate limiting based on auth credentials.
+// Uses auth ID if available, otherwise falls back to a hash of the access token.
+func getTokenKey(auth *cliproxyauth.Auth) string {
+ if auth != nil && auth.ID != "" {
+ return auth.ID
+ }
+ accessToken, _ := kiroCredentials(auth)
+ if len(accessToken) > 16 {
+ return accessToken[:16]
+ }
+ return accessToken
+}
+
+// Execute sends the request to Kiro API and returns the response.
+// Supports automatic token refresh on 401/403 errors.
+func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
+ accessToken, profileArn := kiroCredentials(auth)
+ if accessToken == "" {
+ return resp, fmt.Errorf("kiro: access token not found in auth")
+ }
+
+ // Rate limiting: get token key for tracking
+ tokenKey := getTokenKey(auth)
+ rateLimiter := kiroauth.GetGlobalRateLimiter()
+ cooldownMgr := kiroauth.GetGlobalCooldownManager()
+
+ // Check if token is in cooldown period
+ if cooldownMgr.IsInCooldown(tokenKey) {
+ remaining := cooldownMgr.GetRemainingCooldown(tokenKey)
+ reason := cooldownMgr.GetCooldownReason(tokenKey)
+ log.Warnf("kiro: token %s is in cooldown (reason: %s), remaining: %v", tokenKey, reason, remaining)
+ return resp, fmt.Errorf("kiro: token is in cooldown for %v (reason: %s)", remaining, reason)
+ }
+
+ // Wait for rate limiter before proceeding
+ log.Debugf("kiro: waiting for rate limiter for token %s", tokenKey)
+ rateLimiter.WaitForToken(tokenKey)
+ log.Debugf("kiro: rate limiter cleared for token %s", tokenKey)
+
+ // Check if token is expired before making request (covers both normal and web_search paths)
+ if e.isTokenExpired(accessToken) {
+ log.Infof("kiro: access token expired, attempting recovery")
+
+ // 方案 B: 先尝试从文件重新加载 token(后台刷新器可能已更新文件)
+ reloadedAuth, reloadErr := e.reloadAuthFromFile(auth)
+ if reloadErr == nil && reloadedAuth != nil {
+ // 文件中有更新的 token,使用它
+ auth = reloadedAuth
+ accessToken, profileArn = kiroCredentials(auth)
+ log.Infof("kiro: recovered token from file (background refresh), expires_at: %v", auth.Metadata["expires_at"])
+ } else {
+ // 文件中的 token 也过期了,执行主动刷新
+ log.Debugf("kiro: file reload failed (%v), attempting active refresh", reloadErr)
+ refreshedAuth, refreshErr := e.Refresh(ctx, auth)
+ if refreshErr != nil {
+ log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
+ } else if refreshedAuth != nil {
+ auth = refreshedAuth
+ // Persist the refreshed auth to file so subsequent requests use it
+ if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
+ log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
+ }
+ accessToken, profileArn = kiroCredentials(auth)
+ log.Infof("kiro: token refreshed successfully before request")
+ }
+ }
+ }
+
+ // Check for pure web_search request
+ // Route to MCP endpoint instead of normal Kiro API
+ if kiroclaude.HasWebSearchTool(req.Payload) {
+ log.Infof("kiro: detected pure web_search request (non-stream), routing to MCP endpoint")
+ return e.handleWebSearch(ctx, auth, req, opts, accessToken, profileArn)
+ }
+
+ reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
+ defer reporter.trackFailure(ctx, &err)
+
+ from := opts.SourceFormat
+ to := sdktranslator.FromString("kiro")
+ body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
+
+ kiroModelID := e.mapModelToKiro(req.Model)
+
+ // Determine agentic mode and effective profile ARN using helper functions
+ isAgentic, isChatOnly := determineAgenticMode(req.Model)
+ effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
+
+ // Execute with retry on 401/403 and 429 (quota exhausted)
+ // Note: currentOrigin and kiroPayload are built inside executeWithRetry for each endpoint
+ resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
+ return resp, err
+}
+
+// executeWithRetry performs the actual HTTP request with automatic retry on auth errors.
+// Supports automatic fallback between endpoints with different quotas:
+// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota
+// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota
+// Also supports multi-endpoint fallback similar to Antigravity implementation.
+// tokenKey is used for rate limiting and cooldown tracking.
+func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool, tokenKey string) (cliproxyexecutor.Response, error) {
+ var resp cliproxyexecutor.Response
+ maxRetries := 2 // Allow retries for token refresh + endpoint fallback
+ rateLimiter := kiroauth.GetGlobalRateLimiter()
+ cooldownMgr := kiroauth.GetGlobalCooldownManager()
+ endpointConfigs := getKiroEndpointConfigs(auth)
+ var last429Err error
+
+ for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ {
+ endpointConfig := endpointConfigs[endpointIdx]
+ url := endpointConfig.URL
+ // Use this endpoint's compatible Origin (critical for avoiding 403 errors)
+ currentOrigin = endpointConfig.Origin
+
+ // Rebuild payload with the correct origin for this endpoint
+ // Each endpoint requires its matching Origin value in the request body
+ kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
+
+ log.Debugf("kiro: trying endpoint %d/%d: %s (Name: %s, Origin: %s)",
+ endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin)
+
+ for attempt := 0; attempt <= maxRetries; attempt++ {
+ // Apply human-like delay before first request (not on retries)
+ // This mimics natural user behavior patterns
+ if attempt == 0 && endpointIdx == 0 {
+ kiroauth.ApplyHumanLikeDelay()
+ }
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload))
+ if err != nil {
+ return resp, err
+ }
+
+ httpReq.Header.Set("Content-Type", kiroContentType)
+ httpReq.Header.Set("Accept", kiroAcceptStream)
+ // Only set X-Amz-Target if specified (Q endpoint doesn't require it)
+ if endpointConfig.AmzTarget != "" {
+ httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
+ }
+ // Kiro-specific headers
+ httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
+ httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
+
+ // Apply dynamic fingerprint-based headers
+ applyDynamicFingerprint(httpReq, auth)
+
+ httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
+ httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
+
+ // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.)
+ httpReq.Header.Set("Authorization", "Bearer "+accessToken)
+
+ var attrs map[string]string
+ if auth != nil {
+ attrs = auth.Attributes
+ }
+ util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
+
+ var authID, authLabel, authType, authValue string
+ if auth != nil {
+ authID = auth.ID
+ authLabel = auth.Label
+ authType, authValue = auth.AccountInfo()
+ }
+ recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
+ URL: url,
+ Method: http.MethodPost,
+ Headers: httpReq.Header.Clone(),
+ Body: kiroPayload,
+ Provider: e.Identifier(),
+ AuthID: authID,
+ AuthLabel: authLabel,
+ AuthType: authType,
+ AuthValue: authValue,
+ })
+
+ httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 120*time.Second)
+ httpResp, err := httpClient.Do(httpReq)
+ if err != nil {
+ // Check for context cancellation first - client disconnected, not a server error
+ // Use 499 (Client Closed Request - nginx convention) instead of 500
+ if errors.Is(err, context.Canceled) {
+ log.Debugf("kiro: request canceled by client (context.Canceled)")
+ return resp, statusErr{code: 499, msg: "client canceled request"}
+ }
+
+ // Check for context deadline exceeded - request timed out
+ // Return 504 Gateway Timeout instead of 500
+ if errors.Is(err, context.DeadlineExceeded) {
+ log.Debugf("kiro: request timed out (context.DeadlineExceeded)")
+ return resp, statusErr{code: http.StatusGatewayTimeout, msg: "upstream request timed out"}
+ }
+
+ recordAPIResponseError(ctx, e.cfg, err)
+
+ // Enhanced socket retry: Check if error is retryable (network timeout, connection reset, etc.)
+ retryCfg := defaultRetryConfig()
+ if isRetryableError(err) && attempt < retryCfg.MaxRetries {
+ delay := calculateRetryDelay(attempt, retryCfg)
+ logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("socket error: %v", err), delay, endpointConfig.Name)
+ time.Sleep(delay)
+ continue
+ }
+
+ return resp, err
+ }
+ recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
+
+ // Handle 429 errors (quota exhausted) - try next endpoint
+ // Each endpoint has its own quota pool, so we can try different endpoints
+ if httpResp.StatusCode == 429 {
+ respBody, _ := io.ReadAll(httpResp.Body)
+ _ = httpResp.Body.Close()
+ appendAPIResponseChunk(ctx, e.cfg, respBody)
+
+ // Record failure and set cooldown for 429
+ rateLimiter.MarkTokenFailed(tokenKey)
+ cooldownDuration := kiroauth.CalculateCooldownFor429(attempt)
+ cooldownMgr.SetCooldown(tokenKey, cooldownDuration, kiroauth.CooldownReason429)
+ log.Warnf("kiro: rate limit hit (429), token %s set to cooldown for %v", tokenKey, cooldownDuration)
+
+ // Preserve last 429 so callers can correctly backoff when all endpoints are exhausted
+ last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+
+ log.Warnf("kiro: %s endpoint quota exhausted (429), will try next endpoint, body: %s",
+ endpointConfig.Name, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody))
+
+ // Break inner retry loop to try next endpoint (which has different quota)
+ break
+ }
+
+ // Handle 5xx server errors with exponential backoff retry
+ // Enhanced: Use retryConfig for consistent retry behavior
+ if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 {
+ respBody, _ := io.ReadAll(httpResp.Body)
+ _ = httpResp.Body.Close()
+ appendAPIResponseChunk(ctx, e.cfg, respBody)
+
+ retryCfg := defaultRetryConfig()
+ // Check if this specific 5xx code is retryable (502, 503, 504)
+ if isRetryableHTTPStatus(httpResp.StatusCode) && attempt < retryCfg.MaxRetries {
+ delay := calculateRetryDelay(attempt, retryCfg)
+ logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("HTTP %d", httpResp.StatusCode), delay, endpointConfig.Name)
+ time.Sleep(delay)
+ continue
+ } else if attempt < maxRetries {
+ // Fallback for other 5xx errors (500, 501, etc.)
+ backoff := time.Duration(1< 30*time.Second {
+ backoff = 30 * time.Second
+ }
+ log.Warnf("kiro: server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries)
+ time.Sleep(backoff)
+ continue
+ }
+ log.Errorf("kiro: server error %d after %d retries", httpResp.StatusCode, maxRetries)
+ return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+
+ // Handle 401 errors with token refresh and retry
+ // 401 = Unauthorized (token expired/invalid) - refresh token
+ if httpResp.StatusCode == 401 {
+ respBody, _ := io.ReadAll(httpResp.Body)
+ _ = httpResp.Body.Close()
+ appendAPIResponseChunk(ctx, e.cfg, respBody)
+
+ log.Warnf("kiro: received 401 error, attempting token refresh")
+ refreshedAuth, refreshErr := e.Refresh(ctx, auth)
+ if refreshErr != nil {
+ log.Errorf("kiro: token refresh failed: %v", refreshErr)
+ return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+
+ if refreshedAuth != nil {
+ auth = refreshedAuth
+ // Persist the refreshed auth to file so subsequent requests use it
+ if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
+ log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
+ // Continue anyway - the token is valid for this request
+ }
+ accessToken, profileArn = kiroCredentials(auth)
+ // Rebuild payload with new profile ARN if changed
+ kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
+ if attempt < maxRetries {
+ log.Infof("kiro: token refreshed successfully, retrying request (attempt %d/%d)", attempt+1, maxRetries+1)
+ continue
+ }
+ log.Infof("kiro: token refreshed successfully, no retries remaining")
+ }
+
+ log.Warnf("kiro request error, status: 401, body: %s", summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody))
+ return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+
+ // Handle 402 errors - Monthly Limit Reached
+ if httpResp.StatusCode == 402 {
+ respBody, _ := io.ReadAll(httpResp.Body)
+ _ = httpResp.Body.Close()
+ appendAPIResponseChunk(ctx, e.cfg, respBody)
+
+ log.Warnf("kiro: received 402 (monthly limit). Upstream body: %s", string(respBody))
+
+ // Return upstream error body directly
+ return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+
+ // Handle 403 errors - Access Denied / Token Expired
+ // Do NOT switch endpoints for 403 errors
+ if httpResp.StatusCode == 403 {
+ respBody, _ := io.ReadAll(httpResp.Body)
+ _ = httpResp.Body.Close()
+ appendAPIResponseChunk(ctx, e.cfg, respBody)
+
+ // Log the 403 error details for debugging
+ log.Warnf("kiro: received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody))
+
+ respBodyStr := string(respBody)
+
+ // Check for SUSPENDED status - return immediately without retry
+ if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") {
+ // Set long cooldown for suspended accounts
+ rateLimiter.CheckAndMarkSuspended(tokenKey, respBodyStr)
+ cooldownMgr.SetCooldown(tokenKey, kiroauth.LongCooldown, kiroauth.CooldownReasonSuspended)
+ log.Errorf("kiro: account is suspended, token %s set to cooldown for %v", tokenKey, kiroauth.LongCooldown)
+ return resp, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)}
+ }
+
+ // Check if this looks like a token-related 403 (some APIs return 403 for expired tokens)
+ isTokenRelated := strings.Contains(respBodyStr, "token") ||
+ strings.Contains(respBodyStr, "expired") ||
+ strings.Contains(respBodyStr, "invalid") ||
+ strings.Contains(respBodyStr, "unauthorized")
+
+ if isTokenRelated && attempt < maxRetries {
+ log.Warnf("kiro: 403 appears token-related, attempting token refresh")
+ refreshedAuth, refreshErr := e.Refresh(ctx, auth)
+ if refreshErr != nil {
+ log.Errorf("kiro: token refresh failed: %v", refreshErr)
+ // Token refresh failed - return error immediately
+ return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+ if refreshedAuth != nil {
+ auth = refreshedAuth
+ // Persist the refreshed auth to file so subsequent requests use it
+ if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
+ log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
+ // Continue anyway - the token is valid for this request
+ }
+ accessToken, profileArn = kiroCredentials(auth)
+ kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
+ log.Infof("kiro: token refreshed for 403, retrying request")
+ continue
+ }
+ }
+
+ // For non-token 403 or after max retries, return error immediately
+ // Do NOT switch endpoints for 403 errors
+ log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)")
+ return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+
+ if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
+ b, _ := io.ReadAll(httpResp.Body)
+ appendAPIResponseChunk(ctx, e.cfg, b)
+ log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
+ err = statusErr{code: httpResp.StatusCode, msg: string(b)}
+ if errClose := httpResp.Body.Close(); errClose != nil {
+ log.Errorf("response body close error: %v", errClose)
+ }
+ return resp, err
+ }
+
+ defer func() {
+ if errClose := httpResp.Body.Close(); errClose != nil {
+ log.Errorf("response body close error: %v", errClose)
+ }
+ }()
+
+ content, toolUses, usageInfo, stopReason, err := e.parseEventStream(httpResp.Body)
+ if err != nil {
+ recordAPIResponseError(ctx, e.cfg, err)
+ return resp, err
+ }
+
+ // Fallback for usage if missing from upstream
+
+ // 1. Estimate InputTokens if missing
+ if usageInfo.InputTokens == 0 {
+ if enc, encErr := getTokenizer(req.Model); encErr == nil {
+ if inp, countErr := countOpenAIChatTokens(enc, opts.OriginalRequest); countErr == nil {
+ usageInfo.InputTokens = inp
+ }
+ }
+ }
+
+ // 2. Estimate OutputTokens if missing and content is available
+ if usageInfo.OutputTokens == 0 && len(content) > 0 {
+ // Use tiktoken for more accurate output token calculation
+ if enc, encErr := getTokenizer(req.Model); encErr == nil {
+ if tokenCount, countErr := enc.Count(content); countErr == nil {
+ usageInfo.OutputTokens = int64(tokenCount)
+ }
+ }
+ // Fallback to character count estimation if tiktoken fails
+ if usageInfo.OutputTokens == 0 {
+ usageInfo.OutputTokens = int64(len(content) / 4)
+ if usageInfo.OutputTokens == 0 {
+ usageInfo.OutputTokens = 1
+ }
+ }
+ }
+
+ // 3. Update TotalTokens
+ usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens
+
+ appendAPIResponseChunk(ctx, e.cfg, []byte(content))
+ reporter.publish(ctx, usageInfo)
+
+ // Record success for rate limiting
+ rateLimiter.MarkTokenSuccess(tokenKey)
+ log.Debugf("kiro: request successful, token %s marked as success", tokenKey)
+
+ // Build response in Claude format for Kiro translator
+ // stopReason is extracted from upstream response by parseEventStream
+ requestedModel := payloadRequestedModel(opts, req.Model)
+ kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, requestedModel, usageInfo, stopReason)
+ out := sdktranslator.TranslateNonStream(ctx, to, from, requestedModel, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil)
+ resp = cliproxyexecutor.Response{Payload: []byte(out)}
+ return resp, nil
+ }
+ // Inner retry loop exhausted for this endpoint, try next endpoint
+ // Note: This code is unreachable because all paths in the inner loop
+ // either return or continue. Kept as comment for documentation.
+ }
+
+ // All endpoints exhausted
+ if last429Err != nil {
+ return resp, last429Err
+ }
+ return resp, fmt.Errorf("kiro: all endpoints exhausted")
+}
+
+// ExecuteStream handles streaming requests to Kiro API.
+// Supports automatic token refresh on 401/403 errors and quota fallback on 429.
+func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
+ accessToken, profileArn := kiroCredentials(auth)
+ if accessToken == "" {
+ return nil, fmt.Errorf("kiro: access token not found in auth")
+ }
+
+ // Rate limiting: get token key for tracking
+ tokenKey := getTokenKey(auth)
+ rateLimiter := kiroauth.GetGlobalRateLimiter()
+ cooldownMgr := kiroauth.GetGlobalCooldownManager()
+
+ // Check if token is in cooldown period
+ if cooldownMgr.IsInCooldown(tokenKey) {
+ remaining := cooldownMgr.GetRemainingCooldown(tokenKey)
+ reason := cooldownMgr.GetCooldownReason(tokenKey)
+ log.Warnf("kiro: token %s is in cooldown (reason: %s), remaining: %v", tokenKey, reason, remaining)
+ return nil, fmt.Errorf("kiro: token is in cooldown for %v (reason: %s)", remaining, reason)
+ }
+
+ // Wait for rate limiter before proceeding
+ log.Debugf("kiro: stream waiting for rate limiter for token %s", tokenKey)
+ rateLimiter.WaitForToken(tokenKey)
+ log.Debugf("kiro: stream rate limiter cleared for token %s", tokenKey)
+
+ // Check if token is expired before making request (covers both normal and web_search paths)
+ if e.isTokenExpired(accessToken) {
+ log.Infof("kiro: access token expired, attempting recovery before stream request")
+
+ // 方案 B: 先尝试从文件重新加载 token(后台刷新器可能已更新文件)
+ reloadedAuth, reloadErr := e.reloadAuthFromFile(auth)
+ if reloadErr == nil && reloadedAuth != nil {
+ // 文件中有更新的 token,使用它
+ auth = reloadedAuth
+ accessToken, profileArn = kiroCredentials(auth)
+ log.Infof("kiro: recovered token from file (background refresh) for stream, expires_at: %v", auth.Metadata["expires_at"])
+ } else {
+ // 文件中的 token 也过期了,执行主动刷新
+ log.Debugf("kiro: file reload failed (%v), attempting active refresh for stream", reloadErr)
+ refreshedAuth, refreshErr := e.Refresh(ctx, auth)
+ if refreshErr != nil {
+ log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr)
+ } else if refreshedAuth != nil {
+ auth = refreshedAuth
+ // Persist the refreshed auth to file so subsequent requests use it
+ if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
+ log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
+ }
+ accessToken, profileArn = kiroCredentials(auth)
+ log.Infof("kiro: token refreshed successfully before stream request")
+ }
+ }
+ }
+
+ // Check for pure web_search request
+ // Route to MCP endpoint instead of normal Kiro API
+ if kiroclaude.HasWebSearchTool(req.Payload) {
+ log.Infof("kiro: detected pure web_search request, routing to MCP endpoint")
+ return e.handleWebSearchStream(ctx, auth, req, opts, accessToken, profileArn)
+ }
+
+ reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
+ defer reporter.trackFailure(ctx, &err)
+
+ from := opts.SourceFormat
+ to := sdktranslator.FromString("kiro")
+ body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
+
+ kiroModelID := e.mapModelToKiro(req.Model)
+
+ // Determine agentic mode and effective profile ARN using helper functions
+ isAgentic, isChatOnly := determineAgenticMode(req.Model)
+ effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
+
+ // Execute stream with retry on 401/403 and 429 (quota exhausted)
+ // Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint
+ return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
+}
+
+// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors.
+// Supports automatic fallback between endpoints with different quotas:
+// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota
+// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota
+// Also supports multi-endpoint fallback similar to Antigravity implementation.
+// tokenKey is used for rate limiting and cooldown tracking.
+func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool, tokenKey string) (<-chan cliproxyexecutor.StreamChunk, error) {
+ maxRetries := 2 // Allow retries for token refresh + endpoint fallback
+ rateLimiter := kiroauth.GetGlobalRateLimiter()
+ cooldownMgr := kiroauth.GetGlobalCooldownManager()
+ endpointConfigs := getKiroEndpointConfigs(auth)
+ var last429Err error
+
+ for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ {
+ endpointConfig := endpointConfigs[endpointIdx]
+ url := endpointConfig.URL
+ // Use this endpoint's compatible Origin (critical for avoiding 403 errors)
+ currentOrigin = endpointConfig.Origin
+
+ // Rebuild payload with the correct origin for this endpoint
+ // Each endpoint requires its matching Origin value in the request body
+ kiroPayload, thinkingEnabled := buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
+
+ log.Debugf("kiro: stream trying endpoint %d/%d: %s (Name: %s, Origin: %s)",
+ endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin)
+
+ for attempt := 0; attempt <= maxRetries; attempt++ {
+ // Apply human-like delay before first streaming request (not on retries)
+ // This mimics natural user behavior patterns
+ // Note: Delay is NOT applied during streaming response - only before initial request
+ if attempt == 0 && endpointIdx == 0 {
+ kiroauth.ApplyHumanLikeDelay()
+ }
+
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload))
+ if err != nil {
+ return nil, err
+ }
+
+ httpReq.Header.Set("Content-Type", kiroContentType)
+ httpReq.Header.Set("Accept", kiroAcceptStream)
+ // Only set X-Amz-Target if specified (Q endpoint doesn't require it)
+ if endpointConfig.AmzTarget != "" {
+ httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget)
+ }
+ // Kiro-specific headers
+ httpReq.Header.Set("x-amzn-kiro-agent-mode", kiroIDEAgentModeVibe)
+ httpReq.Header.Set("x-amzn-codewhisperer-optout", "true")
+
+ // Apply dynamic fingerprint-based headers
+ applyDynamicFingerprint(httpReq, auth)
+
+ httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
+ httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
+
+ // Bearer token authentication for all auth types (Builder ID, IDC, social, etc.)
+ httpReq.Header.Set("Authorization", "Bearer "+accessToken)
+
+ var attrs map[string]string
+ if auth != nil {
+ attrs = auth.Attributes
+ }
+ util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
+
+ var authID, authLabel, authType, authValue string
+ if auth != nil {
+ authID = auth.ID
+ authLabel = auth.Label
+ authType, authValue = auth.AccountInfo()
+ }
+ recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
+ URL: url,
+ Method: http.MethodPost,
+ Headers: httpReq.Header.Clone(),
+ Body: kiroPayload,
+ Provider: e.Identifier(),
+ AuthID: authID,
+ AuthLabel: authLabel,
+ AuthType: authType,
+ AuthValue: authValue,
+ })
+
+ httpClient := newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 0)
+ httpResp, err := httpClient.Do(httpReq)
+ if err != nil {
+ recordAPIResponseError(ctx, e.cfg, err)
+
+ // Enhanced socket retry for streaming: Check if error is retryable (network timeout, connection reset, etc.)
+ retryCfg := defaultRetryConfig()
+ if isRetryableError(err) && attempt < retryCfg.MaxRetries {
+ delay := calculateRetryDelay(attempt, retryCfg)
+ logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream socket error: %v", err), delay, endpointConfig.Name)
+ time.Sleep(delay)
+ continue
+ }
+
+ return nil, err
+ }
+ recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
+
+ // Handle 429 errors (quota exhausted) - try next endpoint
+ // Each endpoint has its own quota pool, so we can try different endpoints
+ if httpResp.StatusCode == 429 {
+ respBody, _ := io.ReadAll(httpResp.Body)
+ _ = httpResp.Body.Close()
+ appendAPIResponseChunk(ctx, e.cfg, respBody)
+
+ // Record failure and set cooldown for 429
+ rateLimiter.MarkTokenFailed(tokenKey)
+ cooldownDuration := kiroauth.CalculateCooldownFor429(attempt)
+ cooldownMgr.SetCooldown(tokenKey, cooldownDuration, kiroauth.CooldownReason429)
+ log.Warnf("kiro: stream rate limit hit (429), token %s set to cooldown for %v", tokenKey, cooldownDuration)
+
+ // Preserve last 429 so callers can correctly backoff when all endpoints are exhausted
+ last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+
+ log.Warnf("kiro: stream %s endpoint quota exhausted (429), will try next endpoint, body: %s",
+ endpointConfig.Name, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody))
+
+ // Break inner retry loop to try next endpoint (which has different quota)
+ break
+ }
+
+ // Handle 5xx server errors with exponential backoff retry
+ // Enhanced: Use retryConfig for consistent retry behavior
+ if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 {
+ respBody, _ := io.ReadAll(httpResp.Body)
+ _ = httpResp.Body.Close()
+ appendAPIResponseChunk(ctx, e.cfg, respBody)
+
+ retryCfg := defaultRetryConfig()
+ // Check if this specific 5xx code is retryable (502, 503, 504)
+ if isRetryableHTTPStatus(httpResp.StatusCode) && attempt < retryCfg.MaxRetries {
+ delay := calculateRetryDelay(attempt, retryCfg)
+ logRetryAttempt(attempt, retryCfg.MaxRetries, fmt.Sprintf("stream HTTP %d", httpResp.StatusCode), delay, endpointConfig.Name)
+ time.Sleep(delay)
+ continue
+ } else if attempt < maxRetries {
+ // Fallback for other 5xx errors (500, 501, etc.)
+ backoff := time.Duration(1< 30*time.Second {
+ backoff = 30 * time.Second
+ }
+ log.Warnf("kiro: stream server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries)
+ time.Sleep(backoff)
+ continue
+ }
+ log.Errorf("kiro: stream server error %d after %d retries", httpResp.StatusCode, maxRetries)
+ return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+
+ // Handle 400 errors - Credential/Validation issues
+ // Do NOT switch endpoints - return error immediately
+ if httpResp.StatusCode == 400 {
+ respBody, _ := io.ReadAll(httpResp.Body)
+ _ = httpResp.Body.Close()
+ appendAPIResponseChunk(ctx, e.cfg, respBody)
+
+ log.Warnf("kiro: received 400 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody))
+
+ // 400 errors indicate request validation issues - return immediately without retry
+ return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+
+ // Handle 401 errors with token refresh and retry
+ // 401 = Unauthorized (token expired/invalid) - refresh token
+ if httpResp.StatusCode == 401 {
+ respBody, _ := io.ReadAll(httpResp.Body)
+ _ = httpResp.Body.Close()
+ appendAPIResponseChunk(ctx, e.cfg, respBody)
+
+ log.Warnf("kiro: stream received 401 error, attempting token refresh")
+ refreshedAuth, refreshErr := e.Refresh(ctx, auth)
+ if refreshErr != nil {
+ log.Errorf("kiro: token refresh failed: %v", refreshErr)
+ return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+
+ if refreshedAuth != nil {
+ auth = refreshedAuth
+ // Persist the refreshed auth to file so subsequent requests use it
+ if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
+ log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
+ // Continue anyway - the token is valid for this request
+ }
+ accessToken, profileArn = kiroCredentials(auth)
+ // Rebuild payload with new profile ARN if changed
+ kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
+ if attempt < maxRetries {
+ log.Infof("kiro: token refreshed successfully, retrying stream request (attempt %d/%d)", attempt+1, maxRetries+1)
+ continue
+ }
+ log.Infof("kiro: token refreshed successfully, no retries remaining")
+ }
+
+ log.Warnf("kiro stream error, status: 401, body: %s", string(respBody))
+ return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+
+ // Handle 402 errors - Monthly Limit Reached
+ if httpResp.StatusCode == 402 {
+ respBody, _ := io.ReadAll(httpResp.Body)
+ _ = httpResp.Body.Close()
+ appendAPIResponseChunk(ctx, e.cfg, respBody)
+
+ log.Warnf("kiro: stream received 402 (monthly limit). Upstream body: %s", string(respBody))
+
+ // Return upstream error body directly
+ return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+
+ // Handle 403 errors - Access Denied / Token Expired
+ // Do NOT switch endpoints for 403 errors
+ if httpResp.StatusCode == 403 {
+ respBody, _ := io.ReadAll(httpResp.Body)
+ _ = httpResp.Body.Close()
+ appendAPIResponseChunk(ctx, e.cfg, respBody)
+
+ // Log the 403 error details for debugging
+ log.Warnf("kiro: stream received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, string(respBody))
+
+ respBodyStr := string(respBody)
+
+ // Check for SUSPENDED status - return immediately without retry
+ if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") {
+ // Set long cooldown for suspended accounts
+ rateLimiter.CheckAndMarkSuspended(tokenKey, respBodyStr)
+ cooldownMgr.SetCooldown(tokenKey, kiroauth.LongCooldown, kiroauth.CooldownReasonSuspended)
+ log.Errorf("kiro: stream account is suspended, token %s set to cooldown for %v", tokenKey, kiroauth.LongCooldown)
+ return nil, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)}
+ }
+
+ // Check if this looks like a token-related 403 (some APIs return 403 for expired tokens)
+ isTokenRelated := strings.Contains(respBodyStr, "token") ||
+ strings.Contains(respBodyStr, "expired") ||
+ strings.Contains(respBodyStr, "invalid") ||
+ strings.Contains(respBodyStr, "unauthorized")
+
+ if isTokenRelated && attempt < maxRetries {
+ log.Warnf("kiro: 403 appears token-related, attempting token refresh")
+ refreshedAuth, refreshErr := e.Refresh(ctx, auth)
+ if refreshErr != nil {
+ log.Errorf("kiro: token refresh failed: %v", refreshErr)
+ // Token refresh failed - return error immediately
+ return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+ if refreshedAuth != nil {
+ auth = refreshedAuth
+ // Persist the refreshed auth to file so subsequent requests use it
+ if persistErr := e.persistRefreshedAuth(auth); persistErr != nil {
+ log.Warnf("kiro: failed to persist refreshed auth: %v", persistErr)
+ // Continue anyway - the token is valid for this request
+ }
+ accessToken, profileArn = kiroCredentials(auth)
+ kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from, opts.Headers)
+ log.Infof("kiro: token refreshed for 403, retrying stream request")
+ continue
+ }
+ }
+
+ // For non-token 403 or after max retries, return error immediately
+ // Do NOT switch endpoints for 403 errors
+ log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)")
+ return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)}
+ }
+
+ if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
+ b, _ := io.ReadAll(httpResp.Body)
+ appendAPIResponseChunk(ctx, e.cfg, b)
+ log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(b))
+ if errClose := httpResp.Body.Close(); errClose != nil {
+ log.Errorf("response body close error: %v", errClose)
+ }
+ return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
+ }
+
+ out := make(chan cliproxyexecutor.StreamChunk)
+
+ // Record success immediately since connection was established successfully
+ // Streaming errors will be handled separately
+ rateLimiter.MarkTokenSuccess(tokenKey)
+ log.Debugf("kiro: stream request successful, token %s marked as success", tokenKey)
+
+ go func(resp *http.Response, thinkingEnabled bool) {
+ defer close(out)
+ defer func() {
+ if r := recover(); r != nil {
+ log.Errorf("kiro: panic in stream handler: %v", r)
+ out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("internal error: %v", r)}
+ }
+ }()
+ defer func() {
+ if errClose := resp.Body.Close(); errClose != nil {
+ log.Errorf("response body close error: %v", errClose)
+ }
+ }()
+
+ // Kiro API always returns tags regardless of request parameters
+ // So we always enable thinking parsing for Kiro responses
+ log.Debugf("kiro: stream thinkingEnabled = %v (always true for Kiro)", thinkingEnabled)
+
+ e.streamToChannel(ctx, resp.Body, out, from, payloadRequestedModel(opts, req.Model), opts.OriginalRequest, body, reporter, thinkingEnabled)
+ }(httpResp, thinkingEnabled)
+
+ return out, nil
+ }
+ // Inner retry loop exhausted for this endpoint, try next endpoint
+ // Note: This code is unreachable because all paths in the inner loop
+ // either return or continue. Kept as comment for documentation.
+ }
+
+ // All endpoints exhausted
+ if last429Err != nil {
+ return nil, last429Err
+ }
+ return nil, fmt.Errorf("kiro: stream all endpoints exhausted")
+}
+
+// kiroCredentials extracts access token and profile ARN from auth.
+func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) {
+ if auth == nil {
+ return "", ""
+ }
+
+ // Try Metadata first (wrapper format)
+ if auth.Metadata != nil {
+ if token, ok := auth.Metadata["access_token"].(string); ok {
+ accessToken = token
+ }
+ if arn, ok := auth.Metadata["profile_arn"].(string); ok {
+ profileArn = arn
+ }
+ }
+
+ // Try Attributes
+ if accessToken == "" && auth.Attributes != nil {
+ accessToken = auth.Attributes["access_token"]
+ profileArn = auth.Attributes["profile_arn"]
+ }
+
+ // Try direct fields from flat JSON format (new AWS Builder ID format)
+ if accessToken == "" && auth.Metadata != nil {
+ if token, ok := auth.Metadata["accessToken"].(string); ok {
+ accessToken = token
+ }
+ if arn, ok := auth.Metadata["profileArn"].(string); ok {
+ profileArn = arn
+ }
+ }
+
+ return accessToken, profileArn
+}
+
+// findRealThinkingEndTag finds the real end tag, skipping false positives.
+// Returns -1 if no real end tag is found.
+//
+// Real tags from Kiro API have specific characteristics:
+// - Usually preceded by newline (.\n)
+// - Usually followed by newline (\n\n)
+// - Not inside code blocks or inline code
+//
+// False positives (discussion text) have characteristics:
+// - In the middle of a sentence
+// - Preceded by discussion words like "标签", "tag", "returns"
+// - Inside code blocks or inline code
+//
+// Parameters:
+// - content: the content to search in
+// - alreadyInCodeBlock: whether we're already inside a code block from previous chunks
+// - alreadyInInlineCode: whether we're already inside inline code from previous chunks
+func findRealThinkingEndTag(content string, alreadyInCodeBlock, alreadyInInlineCode bool) int {
+ searchStart := 0
+ for {
+ endIdx := strings.Index(content[searchStart:], kirocommon.ThinkingEndTag)
+ if endIdx < 0 {
+ return -1
+ }
+ endIdx += searchStart // Adjust to absolute position
+
+ textBeforeEnd := content[:endIdx]
+ textAfterEnd := content[endIdx+len(kirocommon.ThinkingEndTag):]
+
+ // Check 1: Is it inside inline code?
+ // Count backticks in current content and add state from previous chunks
+ backtickCount := strings.Count(textBeforeEnd, "`")
+ effectiveInInlineCode := alreadyInInlineCode
+ if backtickCount%2 == 1 {
+ effectiveInInlineCode = !effectiveInInlineCode
+ }
+ if effectiveInInlineCode {
+ log.Debugf("kiro: found inside inline code at pos %d, skipping", endIdx)
+ searchStart = endIdx + len(kirocommon.ThinkingEndTag)
+ continue
+ }
+
+ // Check 2: Is it inside a code block?
+ // Count fences in current content and add state from previous chunks
+ fenceCount := strings.Count(textBeforeEnd, "```")
+ altFenceCount := strings.Count(textBeforeEnd, "~~~")
+ effectiveInCodeBlock := alreadyInCodeBlock
+ if fenceCount%2 == 1 || altFenceCount%2 == 1 {
+ effectiveInCodeBlock = !effectiveInCodeBlock
+ }
+ if effectiveInCodeBlock {
+ log.Debugf("kiro: found inside code block at pos %d, skipping", endIdx)
+ searchStart = endIdx + len(kirocommon.ThinkingEndTag)
+ continue
+ }
+
+ // Check 3: Real tags are usually preceded by newline or at start
+ // and followed by newline or at end. Check the format.
+ charBeforeTag := byte(0)
+ if endIdx > 0 {
+ charBeforeTag = content[endIdx-1]
+ }
+ charAfterTag := byte(0)
+ if len(textAfterEnd) > 0 {
+ charAfterTag = textAfterEnd[0]
+ }
+
+ // Real end tag format: preceded by newline OR end of sentence (. ! ?)
+ // and followed by newline OR end of content
+ isPrecededByNewlineOrSentenceEnd := charBeforeTag == '\n' || charBeforeTag == '.' ||
+ charBeforeTag == '!' || charBeforeTag == '?' || charBeforeTag == 0
+ isFollowedByNewlineOrEnd := charAfterTag == '\n' || charAfterTag == 0
+
+ // If the tag has proper formatting (newline before/after), it's likely real
+ if isPrecededByNewlineOrSentenceEnd && isFollowedByNewlineOrEnd {
+ log.Debugf("kiro: found properly formatted at pos %d", endIdx)
+ return endIdx
+ }
+
+ // Check 4: Is the tag preceded by discussion keywords on the same line?
+ lastNewlineIdx := strings.LastIndex(textBeforeEnd, "\n")
+ lineBeforeTag := textBeforeEnd
+ if lastNewlineIdx >= 0 {
+ lineBeforeTag = textBeforeEnd[lastNewlineIdx+1:]
+ }
+ lineBeforeTagLower := strings.ToLower(lineBeforeTag)
+
+ // Discussion patterns - if found, this is likely discussion text
+ discussionPatterns := []string{
+ "标签", "返回", "输出", "包含", "使用", "解析", "转换", "生成", // Chinese
+ "tag", "return", "output", "contain", "use", "parse", "emit", "convert", "generate", // English
+ "", // discussing both tags together
+ "``", // explicitly in inline code
+ }
+ isDiscussion := false
+ for _, pattern := range discussionPatterns {
+ if strings.Contains(lineBeforeTagLower, pattern) {
+ isDiscussion = true
+ break
+ }
+ }
+ if isDiscussion {
+ log.Debugf("kiro: found after discussion text at pos %d, skipping", endIdx)
+ searchStart = endIdx + len(kirocommon.ThinkingEndTag)
+ continue
+ }
+
+ // Check 5: Is there text immediately after on the same line?
+ // Real end tags don't have text immediately after on the same line
+ if len(textAfterEnd) > 0 && charAfterTag != '\n' && charAfterTag != 0 {
+ // Find the next newline
+ nextNewline := strings.Index(textAfterEnd, "\n")
+ var textOnSameLine string
+ if nextNewline >= 0 {
+ textOnSameLine = textAfterEnd[:nextNewline]
+ } else {
+ textOnSameLine = textAfterEnd
+ }
+ // If there's non-whitespace text on the same line after the tag, it's discussion
+ if strings.TrimSpace(textOnSameLine) != "" {
+ log.Debugf("kiro: found with text after on same line at pos %d, skipping", endIdx)
+ searchStart = endIdx + len(kirocommon.ThinkingEndTag)
+ continue
+ }
+ }
+
+ // Check 6: Is there another tag after this ?
+ if strings.Contains(textAfterEnd, kirocommon.ThinkingStartTag) {
+ nextStartIdx := strings.Index(textAfterEnd, kirocommon.ThinkingStartTag)
+ textBeforeNextStart := textAfterEnd[:nextStartIdx]
+ nextBacktickCount := strings.Count(textBeforeNextStart, "`")
+ nextFenceCount := strings.Count(textBeforeNextStart, "```")
+ nextAltFenceCount := strings.Count(textBeforeNextStart, "~~~")
+
+ // If the next is NOT in code, then this is discussion text
+ if nextBacktickCount%2 == 0 && nextFenceCount%2 == 0 && nextAltFenceCount%2 == 0 {
+ log.Debugf("kiro: found followed by at pos %d, likely discussion text, skipping", endIdx)
+ searchStart = endIdx + len(kirocommon.ThinkingEndTag)
+ continue
+ }
+ }
+
+ // This looks like a real end tag
+ return endIdx
+ }
+}
+
+// determineAgenticMode determines if the model is an agentic or chat-only variant.
+// Returns (isAgentic, isChatOnly) based on model name suffixes.
+func determineAgenticMode(model string) (isAgentic, isChatOnly bool) {
+ isAgentic = strings.HasSuffix(model, "-agentic")
+ isChatOnly = strings.HasSuffix(model, "-chat")
+ return isAgentic, isChatOnly
+}
+
+// getEffectiveProfileArn determines if profileArn should be included based on auth method.
+// profileArn is only needed for social auth (Google OAuth), not for AWS SSO OIDC (Builder ID/IDC).
+//
+// Detection logic (matching kiro-openai-gateway):
+// 1. Check auth_method field: "builder-id" or "idc"
+// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens)
+// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature)
+func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string {
+ if auth != nil && auth.Metadata != nil {
+ // Check 1: auth_method field (from CLIProxyAPI tokens)
+ if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
+ return "" // AWS SSO OIDC - don't include profileArn
+ }
+ // Check 2: auth_type field (from kiro-cli tokens)
+ if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" {
+ return "" // AWS SSO OIDC - don't include profileArn
+ }
+ // Check 3: client_id + client_secret presence (AWS SSO OIDC signature)
+ _, hasClientID := auth.Metadata["client_id"].(string)
+ _, hasClientSecret := auth.Metadata["client_secret"].(string)
+ if hasClientID && hasClientSecret {
+ return "" // AWS SSO OIDC - don't include profileArn
+ }
+ }
+ return profileArn
+}
+
+// getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method,
+// and logs a warning if profileArn is missing for non-builder-id auth.
+// This consolidates the auth_method check that was previously done separately.
+//
+// AWS SSO OIDC (Builder ID/IDC) users don't need profileArn - sending it causes 403 errors.
+// Only Kiro Desktop (social auth like Google/GitHub) users need profileArn.
+//
+// Detection logic (matching kiro-openai-gateway):
+// 1. Check auth_method field: "builder-id" or "idc"
+// 2. Check auth_type field: "aws_sso_oidc" (from kiro-cli tokens)
+// 3. Check for client_id + client_secret presence (AWS SSO OIDC signature)
+func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string {
+ if auth != nil && auth.Metadata != nil {
+ // Check 1: auth_method field (from CLIProxyAPI tokens)
+ if authMethod, ok := auth.Metadata["auth_method"].(string); ok && (authMethod == "builder-id" || authMethod == "idc") {
+ return "" // AWS SSO OIDC - don't include profileArn
+ }
+ // Check 2: auth_type field (from kiro-cli tokens)
+ if authType, ok := auth.Metadata["auth_type"].(string); ok && authType == "aws_sso_oidc" {
+ return "" // AWS SSO OIDC - don't include profileArn
+ }
+ // Check 3: client_id + client_secret presence (AWS SSO OIDC signature, like kiro-openai-gateway)
+ _, hasClientID := auth.Metadata["client_id"].(string)
+ _, hasClientSecret := auth.Metadata["client_secret"].(string)
+ if hasClientID && hasClientSecret {
+ return "" // AWS SSO OIDC - don't include profileArn
+ }
+ }
+ // For social auth (Kiro Desktop), profileArn is required
+ if profileArn == "" {
+ log.Warnf("kiro: profile ARN not found in auth, API calls may fail")
+ }
+ return profileArn
+}
+
+// mapModelToKiro maps external model names to Kiro model IDs.
+// Supports both Kiro and Amazon Q prefixes since they use the same API.
+// Agentic variants (-agentic suffix) map to the same backend model IDs.
+func (e *KiroExecutor) mapModelToKiro(model string) string {
+ modelMap := map[string]string{
+ // Amazon Q format (amazonq- prefix) - same API as Kiro
+ "amazonq-auto": "auto",
+ "amazonq-claude-opus-4-6": "claude-opus-4.6",
+ "amazonq-claude-sonnet-4-6": "claude-sonnet-4.6",
+ "amazonq-claude-opus-4-5": "claude-opus-4.5",
+ "amazonq-claude-sonnet-4-5": "claude-sonnet-4.5",
+ "amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
+ "amazonq-claude-sonnet-4": "claude-sonnet-4",
+ "amazonq-claude-sonnet-4-20250514": "claude-sonnet-4",
+ "amazonq-claude-haiku-4-5": "claude-haiku-4.5",
+ // Kiro format (kiro- prefix) - valid model names that should be preserved
+ "kiro-claude-opus-4-6": "claude-opus-4.6",
+ "kiro-claude-sonnet-4-6": "claude-sonnet-4.6",
+ "kiro-claude-opus-4-5": "claude-opus-4.5",
+ "kiro-claude-sonnet-4-5": "claude-sonnet-4.5",
+ "kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
+ "kiro-claude-sonnet-4": "claude-sonnet-4",
+ "kiro-claude-sonnet-4-20250514": "claude-sonnet-4",
+ "kiro-claude-haiku-4-5": "claude-haiku-4.5",
+ "kiro-auto": "auto",
+ // Native format (no prefix) - used by Kiro IDE directly
+ "claude-opus-4-6": "claude-opus-4.6",
+ "claude-opus-4.6": "claude-opus-4.6",
+ "claude-sonnet-4-6": "claude-sonnet-4.6",
+ "claude-sonnet-4.6": "claude-sonnet-4.6",
+ "claude-opus-4-5": "claude-opus-4.5",
+ "claude-opus-4.5": "claude-opus-4.5",
+ "claude-haiku-4-5": "claude-haiku-4.5",
+ "claude-haiku-4.5": "claude-haiku-4.5",
+ "claude-sonnet-4-5": "claude-sonnet-4.5",
+ "claude-sonnet-4-5-20250929": "claude-sonnet-4.5",
+ "claude-sonnet-4.5": "claude-sonnet-4.5",
+ "claude-sonnet-4": "claude-sonnet-4",
+ "claude-sonnet-4-20250514": "claude-sonnet-4",
+ "auto": "auto",
+ // Agentic variants (same backend model IDs, but with special system prompt)
+ "claude-opus-4.6-agentic": "claude-opus-4.6",
+ "claude-sonnet-4.6-agentic": "claude-sonnet-4.6",
+ "claude-opus-4.5-agentic": "claude-opus-4.5",
+ "claude-sonnet-4.5-agentic": "claude-sonnet-4.5",
+ "claude-sonnet-4-agentic": "claude-sonnet-4",
+ "claude-haiku-4.5-agentic": "claude-haiku-4.5",
+ "kiro-claude-opus-4-6-agentic": "claude-opus-4.6",
+ "kiro-claude-sonnet-4-6-agentic": "claude-sonnet-4.6",
+ "kiro-claude-opus-4-5-agentic": "claude-opus-4.5",
+ "kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5",
+ "kiro-claude-sonnet-4-agentic": "claude-sonnet-4",
+ "kiro-claude-haiku-4-5-agentic": "claude-haiku-4.5",
+ }
+ if kiroID, ok := modelMap[model]; ok {
+ return kiroID
+ }
+
+ // Smart fallback: try to infer model type from name patterns
+ modelLower := strings.ToLower(model)
+
+ // Check for Haiku variants
+ if strings.Contains(modelLower, "haiku") {
+ log.Debugf("kiro: unknown Haiku model '%s', mapping to claude-haiku-4.5", model)
+ return "claude-haiku-4.5"
+ }
+
+ // Check for Sonnet variants
+ if strings.Contains(modelLower, "sonnet") {
+ // Check for specific version patterns
+ if strings.Contains(modelLower, "3-7") || strings.Contains(modelLower, "3.7") {
+ log.Debugf("kiro: unknown Sonnet 3.7 model '%s', mapping to claude-3-7-sonnet-20250219", model)
+ return "claude-3-7-sonnet-20250219"
+ }
+ if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") {
+ log.Debugf("kiro: unknown Sonnet 4.6 model '%s', mapping to claude-sonnet-4.6", model)
+ return "claude-sonnet-4.6"
+ }
+ if strings.Contains(modelLower, "4-5") || strings.Contains(modelLower, "4.5") {
+ log.Debugf("kiro: unknown Sonnet 4.5 model '%s', mapping to claude-sonnet-4.5", model)
+ return "claude-sonnet-4.5"
+ }
+ // Default to Sonnet 4
+ log.Debugf("kiro: unknown Sonnet model '%s', mapping to claude-sonnet-4", model)
+ return "claude-sonnet-4"
+ }
+
+ // Check for Opus variants
+ if strings.Contains(modelLower, "opus") {
+ if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") {
+ log.Debugf("kiro: unknown Opus 4.6 model '%s', mapping to claude-opus-4.6", model)
+ return "claude-opus-4.6"
+ }
+ log.Debugf("kiro: unknown Opus model '%s', mapping to claude-opus-4.5", model)
+ return "claude-opus-4.5"
+ }
+
+ // Final fallback to Sonnet 4.5 (most commonly used model)
+ log.Warnf("kiro: unknown model '%s', falling back to claude-sonnet-4.5", model)
+ return "claude-sonnet-4.5"
+}
+
+// EventStreamError represents an Event Stream processing error
+type EventStreamError struct {
+ Type string // "fatal", "malformed"
+ Message string
+ Cause error
+}
+
+func (e *EventStreamError) Error() string {
+ if e.Cause != nil {
+ return fmt.Sprintf("event stream %s: %s: %v", e.Type, e.Message, e.Cause)
+ }
+ return fmt.Sprintf("event stream %s: %s", e.Type, e.Message)
+}
+
+// eventStreamMessage represents a parsed AWS Event Stream message
+type eventStreamMessage struct {
+ EventType string // Event type from headers (e.g., "assistantResponseEvent")
+ Payload []byte // JSON payload of the message
+}
+
+// NOTE: Request building functions moved to internal/translator/kiro/claude/kiro_claude_request.go
+// The executor now uses kiroclaude.BuildKiroPayload() instead
+
+// parseEventStream parses AWS Event Stream binary format.
+// Extracts text content, tool uses, and stop_reason from the response.
+// Supports embedded [Called ...] tool calls and input buffering for toolUseEvent.
+// Returns: content, toolUses, usageInfo, stopReason, error
+func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.KiroToolUse, usage.Detail, string, error) {
+ var content strings.Builder
+ var toolUses []kiroclaude.KiroToolUse
+ var usageInfo usage.Detail
+ var stopReason string // Extracted from upstream response
+ reader := bufio.NewReader(body)
+
+ // Tool use state tracking for input buffering and deduplication
+ processedIDs := make(map[string]bool)
+ var currentToolUse *kiroclaude.ToolUseState
+
+ // Upstream usage tracking - Kiro API returns credit usage and context percentage
+ var upstreamContextPercentage float64 // Context usage percentage from upstream (e.g., 78.56)
+
+ for {
+ msg, eventErr := e.readEventStreamMessage(reader)
+ if eventErr != nil {
+ log.Errorf("kiro: parseEventStream error: %v", eventErr)
+ return content.String(), toolUses, usageInfo, stopReason, eventErr
+ }
+ if msg == nil {
+ // Normal end of stream (EOF)
+ break
+ }
+
+ eventType := msg.EventType
+ payload := msg.Payload
+ if len(payload) == 0 {
+ continue
+ }
+
+ var event map[string]interface{}
+ if err := json.Unmarshal(payload, &event); err != nil {
+ log.Debugf("kiro: skipping malformed event: %v", err)
+ continue
+ }
+
+ // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200)
+ // These can appear as top-level fields or nested within the event
+ if errType, hasErrType := event["_type"].(string); hasErrType {
+ // AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."}
+ errMsg := ""
+ if msg, ok := event["message"].(string); ok {
+ errMsg = msg
+ }
+ log.Errorf("kiro: received AWS error in event stream: type=%s, message=%s", errType, errMsg)
+ return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s - %s", errType, errMsg)
+ }
+ if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") {
+ // Generic error event
+ errMsg := ""
+ if msg, ok := event["message"].(string); ok {
+ errMsg = msg
+ } else if errObj, ok := event["error"].(map[string]interface{}); ok {
+ if msg, ok := errObj["message"].(string); ok {
+ errMsg = msg
+ }
+ }
+ log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg)
+ return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s", errMsg)
+ }
+
+ // Extract stop_reason from various event formats
+ // Kiro/Amazon Q API may include stop_reason in different locations
+ if sr := kirocommon.GetString(event, "stop_reason"); sr != "" {
+ stopReason = sr
+ log.Debugf("kiro: parseEventStream found stop_reason (top-level): %s", stopReason)
+ }
+ if sr := kirocommon.GetString(event, "stopReason"); sr != "" {
+ stopReason = sr
+ log.Debugf("kiro: parseEventStream found stopReason (top-level): %s", stopReason)
+ }
+
+ // Handle different event types
+ switch eventType {
+ case "followupPromptEvent":
+ // Filter out followupPrompt events - these are UI suggestions, not content
+ log.Debugf("kiro: parseEventStream ignoring followupPrompt event")
+ continue
+
+ case "assistantResponseEvent":
+ if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok {
+ if contentText, ok := assistantResp["content"].(string); ok {
+ content.WriteString(contentText)
+ }
+ // Extract stop_reason from assistantResponseEvent
+ if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" {
+ stopReason = sr
+ log.Debugf("kiro: parseEventStream found stop_reason in assistantResponseEvent: %s", stopReason)
+ }
+ if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" {
+ stopReason = sr
+ log.Debugf("kiro: parseEventStream found stopReason in assistantResponseEvent: %s", stopReason)
+ }
+ // Extract tool uses from response
+ if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok {
+ for _, tuRaw := range toolUsesRaw {
+ if tu, ok := tuRaw.(map[string]interface{}); ok {
+ toolUseID := kirocommon.GetStringValue(tu, "toolUseId")
+ // Check for duplicate
+ if processedIDs[toolUseID] {
+ log.Debugf("kiro: skipping duplicate tool use from assistantResponse: %s", toolUseID)
+ continue
+ }
+ processedIDs[toolUseID] = true
+
+ toolUse := kiroclaude.KiroToolUse{
+ ToolUseID: toolUseID,
+ Name: kirocommon.GetStringValue(tu, "name"),
+ }
+ if input, ok := tu["input"].(map[string]interface{}); ok {
+ toolUse.Input = input
+ }
+ toolUses = append(toolUses, toolUse)
+ }
+ }
+ }
+ }
+ // Also try direct format
+ if contentText, ok := event["content"].(string); ok {
+ content.WriteString(contentText)
+ }
+ // Direct tool uses
+ if toolUsesRaw, ok := event["toolUses"].([]interface{}); ok {
+ for _, tuRaw := range toolUsesRaw {
+ if tu, ok := tuRaw.(map[string]interface{}); ok {
+ toolUseID := kirocommon.GetStringValue(tu, "toolUseId")
+ // Check for duplicate
+ if processedIDs[toolUseID] {
+ log.Debugf("kiro: skipping duplicate direct tool use: %s", toolUseID)
+ continue
+ }
+ processedIDs[toolUseID] = true
+
+ toolUse := kiroclaude.KiroToolUse{
+ ToolUseID: toolUseID,
+ Name: kirocommon.GetStringValue(tu, "name"),
+ }
+ if input, ok := tu["input"].(map[string]interface{}); ok {
+ toolUse.Input = input
+ }
+ toolUses = append(toolUses, toolUse)
+ }
+ }
+ }
+
+ case "toolUseEvent":
+ // Handle dedicated tool use events with input buffering
+ completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs)
+ currentToolUse = newState
+ toolUses = append(toolUses, completedToolUses...)
+
+ case "supplementaryWebLinksEvent":
+ if inputTokens, ok := event["inputTokens"].(float64); ok {
+ usageInfo.InputTokens = int64(inputTokens)
+ }
+ if outputTokens, ok := event["outputTokens"].(float64); ok {
+ usageInfo.OutputTokens = int64(outputTokens)
+ }
+
+ case "messageStopEvent", "message_stop":
+ // Handle message stop events which may contain stop_reason
+ if sr := kirocommon.GetString(event, "stop_reason"); sr != "" {
+ stopReason = sr
+ log.Debugf("kiro: parseEventStream found stop_reason in messageStopEvent: %s", stopReason)
+ }
+ if sr := kirocommon.GetString(event, "stopReason"); sr != "" {
+ stopReason = sr
+ log.Debugf("kiro: parseEventStream found stopReason in messageStopEvent: %s", stopReason)
+ }
+
+ case "messageMetadataEvent", "metadataEvent":
+ // Handle message metadata events which contain token counts
+ // Official format: { tokenUsage: { outputTokens, totalTokens, uncachedInputTokens, cacheReadInputTokens, cacheWriteInputTokens, contextUsagePercentage } }
+ var metadata map[string]interface{}
+ if m, ok := event["messageMetadataEvent"].(map[string]interface{}); ok {
+ metadata = m
+ } else if m, ok := event["metadataEvent"].(map[string]interface{}); ok {
+ metadata = m
+ } else {
+ metadata = event // event itself might be the metadata
+ }
+
+ // Check for nested tokenUsage object (official format)
+ if tokenUsage, ok := metadata["tokenUsage"].(map[string]interface{}); ok {
+ // outputTokens - precise output token count
+ if outputTokens, ok := tokenUsage["outputTokens"].(float64); ok {
+ usageInfo.OutputTokens = int64(outputTokens)
+ log.Infof("kiro: parseEventStream found precise outputTokens in tokenUsage: %d", usageInfo.OutputTokens)
+ }
+ // totalTokens - precise total token count
+ if totalTokens, ok := tokenUsage["totalTokens"].(float64); ok {
+ usageInfo.TotalTokens = int64(totalTokens)
+ log.Infof("kiro: parseEventStream found precise totalTokens in tokenUsage: %d", usageInfo.TotalTokens)
+ }
+ // uncachedInputTokens - input tokens not from cache
+ if uncachedInputTokens, ok := tokenUsage["uncachedInputTokens"].(float64); ok {
+ usageInfo.InputTokens = int64(uncachedInputTokens)
+ log.Infof("kiro: parseEventStream found uncachedInputTokens in tokenUsage: %d", usageInfo.InputTokens)
+ }
+ // cacheReadInputTokens - tokens read from cache
+ if cacheReadTokens, ok := tokenUsage["cacheReadInputTokens"].(float64); ok {
+ // Add to input tokens if we have uncached tokens, otherwise use as input
+ if usageInfo.InputTokens > 0 {
+ usageInfo.InputTokens += int64(cacheReadTokens)
+ } else {
+ usageInfo.InputTokens = int64(cacheReadTokens)
+ }
+ log.Debugf("kiro: parseEventStream found cacheReadInputTokens in tokenUsage: %d", int64(cacheReadTokens))
+ }
+ // contextUsagePercentage - can be used as fallback for input token estimation
+ if ctxPct, ok := tokenUsage["contextUsagePercentage"].(float64); ok {
+ upstreamContextPercentage = ctxPct
+ log.Debugf("kiro: parseEventStream found contextUsagePercentage in tokenUsage: %.2f%%", ctxPct)
+ }
+ }
+
+ // Fallback: check for direct fields in metadata (legacy format)
+ if usageInfo.InputTokens == 0 {
+ if inputTokens, ok := metadata["inputTokens"].(float64); ok {
+ usageInfo.InputTokens = int64(inputTokens)
+ log.Debugf("kiro: parseEventStream found inputTokens in messageMetadataEvent: %d", usageInfo.InputTokens)
+ }
+ }
+ if usageInfo.OutputTokens == 0 {
+ if outputTokens, ok := metadata["outputTokens"].(float64); ok {
+ usageInfo.OutputTokens = int64(outputTokens)
+ log.Debugf("kiro: parseEventStream found outputTokens in messageMetadataEvent: %d", usageInfo.OutputTokens)
+ }
+ }
+ if usageInfo.TotalTokens == 0 {
+ if totalTokens, ok := metadata["totalTokens"].(float64); ok {
+ usageInfo.TotalTokens = int64(totalTokens)
+ log.Debugf("kiro: parseEventStream found totalTokens in messageMetadataEvent: %d", usageInfo.TotalTokens)
+ }
+ }
+
+ case "usageEvent", "usage":
+ // Handle dedicated usage events
+ if inputTokens, ok := event["inputTokens"].(float64); ok {
+ usageInfo.InputTokens = int64(inputTokens)
+ log.Debugf("kiro: parseEventStream found inputTokens in usageEvent: %d", usageInfo.InputTokens)
+ }
+ if outputTokens, ok := event["outputTokens"].(float64); ok {
+ usageInfo.OutputTokens = int64(outputTokens)
+ log.Debugf("kiro: parseEventStream found outputTokens in usageEvent: %d", usageInfo.OutputTokens)
+ }
+ if totalTokens, ok := event["totalTokens"].(float64); ok {
+ usageInfo.TotalTokens = int64(totalTokens)
+ log.Debugf("kiro: parseEventStream found totalTokens in usageEvent: %d", usageInfo.TotalTokens)
+ }
+ // Also check nested usage object
+ if usageObj, ok := event["usage"].(map[string]interface{}); ok {
+ if inputTokens, ok := usageObj["input_tokens"].(float64); ok {
+ usageInfo.InputTokens = int64(inputTokens)
+ } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok {
+ usageInfo.InputTokens = int64(inputTokens)
+ }
+ if outputTokens, ok := usageObj["output_tokens"].(float64); ok {
+ usageInfo.OutputTokens = int64(outputTokens)
+ } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok {
+ usageInfo.OutputTokens = int64(outputTokens)
+ }
+ if totalTokens, ok := usageObj["total_tokens"].(float64); ok {
+ usageInfo.TotalTokens = int64(totalTokens)
+ }
+ log.Debugf("kiro: parseEventStream found usage object: input=%d, output=%d, total=%d",
+ usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens)
+ }
+
+ case "metricsEvent":
+ // Handle metrics events which may contain usage data
+ if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok {
+ if inputTokens, ok := metrics["inputTokens"].(float64); ok {
+ usageInfo.InputTokens = int64(inputTokens)
+ }
+ if outputTokens, ok := metrics["outputTokens"].(float64); ok {
+ usageInfo.OutputTokens = int64(outputTokens)
+ }
+ log.Debugf("kiro: parseEventStream found metricsEvent: input=%d, output=%d",
+ usageInfo.InputTokens, usageInfo.OutputTokens)
+ }
+
+ case "meteringEvent":
+ // Handle metering events from Kiro API (usage billing information)
+ // Official format: { unit: string, unitPlural: string, usage: number }
+ if metering, ok := event["meteringEvent"].(map[string]interface{}); ok {
+ unit := ""
+ if u, ok := metering["unit"].(string); ok {
+ unit = u
+ }
+ usageVal := 0.0
+ if u, ok := metering["usage"].(float64); ok {
+ usageVal = u
+ }
+ log.Infof("kiro: parseEventStream received meteringEvent: usage=%.2f %s", usageVal, unit)
+ // Store metering info for potential billing/statistics purposes
+ // Note: This is separate from token counts - it's AWS billing units
+ } else {
+ // Try direct fields
+ unit := ""
+ if u, ok := event["unit"].(string); ok {
+ unit = u
+ }
+ usageVal := 0.0
+ if u, ok := event["usage"].(float64); ok {
+ usageVal = u
+ }
+ if unit != "" || usageVal > 0 {
+ log.Infof("kiro: parseEventStream received meteringEvent (direct): usage=%.2f %s", usageVal, unit)
+ }
+ }
+
+ case "contextUsageEvent":
+ // Handle context usage events from Kiro API
+ // Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}}
+ if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok {
+ if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok {
+ upstreamContextPercentage = ctxPct
+ log.Debugf("kiro: parseEventStream received contextUsageEvent: %.2f%%", ctxPct*100)
+ }
+ } else {
+ // Try direct field (fallback)
+ if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
+ upstreamContextPercentage = ctxPct
+ log.Debugf("kiro: parseEventStream received contextUsagePercentage (direct): %.2f%%", ctxPct*100)
+ }
+ }
+
+ case "error", "exception", "internalServerException", "invalidStateEvent":
+ // Handle error events from Kiro API stream
+ errMsg := ""
+ errType := eventType
+
+ // Try to extract error message from various formats
+ if msg, ok := event["message"].(string); ok {
+ errMsg = msg
+ } else if errObj, ok := event[eventType].(map[string]interface{}); ok {
+ if msg, ok := errObj["message"].(string); ok {
+ errMsg = msg
+ }
+ if t, ok := errObj["type"].(string); ok {
+ errType = t
+ }
+ } else if errObj, ok := event["error"].(map[string]interface{}); ok {
+ if msg, ok := errObj["message"].(string); ok {
+ errMsg = msg
+ }
+ if t, ok := errObj["type"].(string); ok {
+ errType = t
+ }
+ }
+
+ // Check for specific error reasons
+ if reason, ok := event["reason"].(string); ok {
+ errMsg = fmt.Sprintf("%s (reason: %s)", errMsg, reason)
+ }
+
+ log.Errorf("kiro: parseEventStream received error event: type=%s, message=%s", errType, errMsg)
+
+ // For invalidStateEvent, we may want to continue processing other events
+ if eventType == "invalidStateEvent" {
+ log.Warnf("kiro: invalidStateEvent received, continuing stream processing")
+ continue
+ }
+
+ // For other errors, return the error
+ if errMsg != "" {
+ return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error (%s): %s", errType, errMsg)
+ }
+
+ default:
+ // Check for contextUsagePercentage in any event
+ if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
+ upstreamContextPercentage = ctxPct
+ log.Debugf("kiro: parseEventStream received context usage: %.2f%%", upstreamContextPercentage)
+ }
+ // Log unknown event types for debugging (to discover new event formats)
+ log.Debugf("kiro: parseEventStream unknown event type: %s, payload: %s", eventType, string(payload))
+ }
+
+ // Check for direct token fields in any event (fallback)
+ if usageInfo.InputTokens == 0 {
+ if inputTokens, ok := event["inputTokens"].(float64); ok {
+ usageInfo.InputTokens = int64(inputTokens)
+ log.Debugf("kiro: parseEventStream found direct inputTokens: %d", usageInfo.InputTokens)
+ }
+ }
+ if usageInfo.OutputTokens == 0 {
+ if outputTokens, ok := event["outputTokens"].(float64); ok {
+ usageInfo.OutputTokens = int64(outputTokens)
+ log.Debugf("kiro: parseEventStream found direct outputTokens: %d", usageInfo.OutputTokens)
+ }
+ }
+
+ // Check for usage object in any event (OpenAI format)
+ if usageInfo.InputTokens == 0 || usageInfo.OutputTokens == 0 {
+ if usageObj, ok := event["usage"].(map[string]interface{}); ok {
+ if usageInfo.InputTokens == 0 {
+ if inputTokens, ok := usageObj["input_tokens"].(float64); ok {
+ usageInfo.InputTokens = int64(inputTokens)
+ } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok {
+ usageInfo.InputTokens = int64(inputTokens)
+ }
+ }
+ if usageInfo.OutputTokens == 0 {
+ if outputTokens, ok := usageObj["output_tokens"].(float64); ok {
+ usageInfo.OutputTokens = int64(outputTokens)
+ } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok {
+ usageInfo.OutputTokens = int64(outputTokens)
+ }
+ }
+ if usageInfo.TotalTokens == 0 {
+ if totalTokens, ok := usageObj["total_tokens"].(float64); ok {
+ usageInfo.TotalTokens = int64(totalTokens)
+ }
+ }
+ log.Debugf("kiro: parseEventStream found usage object (fallback): input=%d, output=%d, total=%d",
+ usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens)
+ }
+ }
+
+ // Also check nested supplementaryWebLinksEvent
+ if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok {
+ if inputTokens, ok := usageEvent["inputTokens"].(float64); ok {
+ usageInfo.InputTokens = int64(inputTokens)
+ }
+ if outputTokens, ok := usageEvent["outputTokens"].(float64); ok {
+ usageInfo.OutputTokens = int64(outputTokens)
+ }
+ }
+ }
+
+ // Parse embedded tool calls from content (e.g., [Called tool_name with args: {...}])
+ contentStr := content.String()
+ cleanedContent, embeddedToolUses := kiroclaude.ParseEmbeddedToolCalls(contentStr, processedIDs)
+ toolUses = append(toolUses, embeddedToolUses...)
+
+ // Deduplicate all tool uses
+ toolUses = kiroclaude.DeduplicateToolUses(toolUses)
+
+ // Apply fallback logic for stop_reason if not provided by upstream
+ // Priority: upstream stopReason > tool_use detection > end_turn default
+ if stopReason == "" {
+ if len(toolUses) > 0 {
+ stopReason = "tool_use"
+ log.Debugf("kiro: parseEventStream using fallback stop_reason: tool_use (detected %d tool uses)", len(toolUses))
+ } else {
+ stopReason = "end_turn"
+ log.Debugf("kiro: parseEventStream using fallback stop_reason: end_turn")
+ }
+ }
+
+ // Log warning if response was truncated due to max_tokens
+ if stopReason == "max_tokens" {
+ log.Warnf("kiro: response truncated due to max_tokens limit")
+ }
+
+ // Use contextUsagePercentage to calculate more accurate input tokens
+ // Kiro model has 200k max context, contextUsagePercentage represents the percentage used
+ // Formula: input_tokens = contextUsagePercentage * 200000 / 100
+ if upstreamContextPercentage > 0 {
+ calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100)
+ if calculatedInputTokens > 0 {
+ localEstimate := usageInfo.InputTokens
+ usageInfo.InputTokens = calculatedInputTokens
+ usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens
+ log.Infof("kiro: parseEventStream using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)",
+ upstreamContextPercentage, calculatedInputTokens, localEstimate)
+ }
+ }
+
+ return cleanedContent, toolUses, usageInfo, stopReason, nil
+}
+
+// readEventStreamMessage reads and validates a single AWS Event Stream message.
+// Returns the parsed message or a structured error for different failure modes.
+// This function implements boundary protection and detailed error classification.
+//
+// AWS Event Stream binary format:
+// - Prelude (12 bytes): total_length (4) + headers_length (4) + prelude_crc (4)
+// - Headers (variable): header entries
+// - Payload (variable): JSON data
+// - Message CRC (4 bytes): CRC32C of entire message (not validated, just skipped)
+func (e *KiroExecutor) readEventStreamMessage(reader *bufio.Reader) (*eventStreamMessage, *EventStreamError) {
+ // Read prelude (first 12 bytes: total_len + headers_len + prelude_crc)
+ prelude := make([]byte, 12)
+ _, err := io.ReadFull(reader, prelude)
+ if err == io.EOF {
+ return nil, nil // Normal end of stream
+ }
+ if err != nil {
+ return nil, &EventStreamError{
+ Type: ErrStreamFatal,
+ Message: "failed to read prelude",
+ Cause: err,
+ }
+ }
+
+ totalLength := binary.BigEndian.Uint32(prelude[0:4])
+ headersLength := binary.BigEndian.Uint32(prelude[4:8])
+ // Note: prelude[8:12] is prelude_crc - we read it but don't validate (no CRC check per requirements)
+
+ // Boundary check: minimum frame size
+ if totalLength < minEventStreamFrameSize {
+ return nil, &EventStreamError{
+ Type: ErrStreamMalformed,
+ Message: fmt.Sprintf("invalid message length: %d (minimum is %d)", totalLength, minEventStreamFrameSize),
+ }
+ }
+
+ // Boundary check: maximum message size
+ if totalLength > maxEventStreamMsgSize {
+ return nil, &EventStreamError{
+ Type: ErrStreamMalformed,
+ Message: fmt.Sprintf("message too large: %d bytes (maximum is %d)", totalLength, maxEventStreamMsgSize),
+ }
+ }
+
+ // Boundary check: headers length within message bounds
+ // Message structure: prelude(12) + headers(headersLength) + payload + message_crc(4)
+ // So: headersLength must be <= totalLength - 16 (12 for prelude + 4 for message_crc)
+ if headersLength > totalLength-16 {
+ return nil, &EventStreamError{
+ Type: ErrStreamMalformed,
+ Message: fmt.Sprintf("headers length %d exceeds message bounds (total: %d)", headersLength, totalLength),
+ }
+ }
+
+ // Read the rest of the message (total - 12 bytes already read)
+ remaining := make([]byte, totalLength-12)
+ _, err = io.ReadFull(reader, remaining)
+ if err != nil {
+ return nil, &EventStreamError{
+ Type: ErrStreamFatal,
+ Message: "failed to read message body",
+ Cause: err,
+ }
+ }
+
+ // Extract event type from headers
+ // Headers start at beginning of 'remaining', length is headersLength
+ var eventType string
+ if headersLength > 0 && headersLength <= uint32(len(remaining)) {
+ eventType = e.extractEventTypeFromBytes(remaining[:headersLength])
+ }
+
+ // Calculate payload boundaries
+ // Payload starts after headers, ends before message_crc (last 4 bytes)
+ payloadStart := headersLength
+ payloadEnd := uint32(len(remaining)) - 4 // Skip message_crc at end
+
+ // Validate payload boundaries
+ if payloadStart >= payloadEnd {
+ // No payload, return empty message
+ return &eventStreamMessage{
+ EventType: eventType,
+ Payload: nil,
+ }, nil
+ }
+
+ payload := remaining[payloadStart:payloadEnd]
+
+ return &eventStreamMessage{
+ EventType: eventType,
+ Payload: payload,
+ }, nil
+}
+
+func skipEventStreamHeaderValue(headers []byte, offset int, valueType byte) (int, bool) {
+ switch valueType {
+ case 0, 1: // bool true / bool false
+ return offset, true
+ case 2: // byte
+ if offset+1 > len(headers) {
+ return offset, false
+ }
+ return offset + 1, true
+ case 3: // short
+ if offset+2 > len(headers) {
+ return offset, false
+ }
+ return offset + 2, true
+ case 4: // int
+ if offset+4 > len(headers) {
+ return offset, false
+ }
+ return offset + 4, true
+ case 5: // long
+ if offset+8 > len(headers) {
+ return offset, false
+ }
+ return offset + 8, true
+ case 6: // byte array (2-byte length + data)
+ if offset+2 > len(headers) {
+ return offset, false
+ }
+ valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2]))
+ offset += 2
+ if offset+valueLen > len(headers) {
+ return offset, false
+ }
+ return offset + valueLen, true
+ case 8: // timestamp
+ if offset+8 > len(headers) {
+ return offset, false
+ }
+ return offset + 8, true
+ case 9: // uuid
+ if offset+16 > len(headers) {
+ return offset, false
+ }
+ return offset + 16, true
+ default:
+ return offset, false
+ }
+}
+
+// extractEventTypeFromBytes extracts the event type from raw header bytes (without prelude CRC prefix)
+func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string {
+ offset := 0
+ for offset < len(headers) {
+ nameLen := int(headers[offset])
+ offset++
+ if offset+nameLen > len(headers) {
+ break
+ }
+ name := string(headers[offset : offset+nameLen])
+ offset += nameLen
+
+ if offset >= len(headers) {
+ break
+ }
+ valueType := headers[offset]
+ offset++
+
+ if valueType == 7 { // String type
+ if offset+2 > len(headers) {
+ break
+ }
+ valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2]))
+ offset += 2
+ if offset+valueLen > len(headers) {
+ break
+ }
+ value := string(headers[offset : offset+valueLen])
+ offset += valueLen
+
+ if name == ":event-type" {
+ return value
+ }
+ continue
+ }
+
+ nextOffset, ok := skipEventStreamHeaderValue(headers, offset, valueType)
+ if !ok {
+ break
+ }
+ offset = nextOffset
+ }
+ return ""
+}
+
+// NOTE: Response building functions moved to internal/translator/kiro/claude/kiro_claude_response.go
+// The executor now uses kiroclaude.BuildClaudeResponse() and kiroclaude.ExtractThinkingFromContent() instead
+
+// streamToChannel converts AWS Event Stream to channel-based streaming.
+// Supports tool calling - emits tool_use content blocks when tools are used.
+// Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent.
+// Implements duplicate content filtering using lastContentEvent detection (based on AIClient-2-API).
+// Extracts stop_reason from upstream events when available.
+// thinkingEnabled controls whether tags are parsed - only parse when request enabled thinking.
+func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter, thinkingEnabled bool) {
+ reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers
+ var totalUsage usage.Detail
+ var hasToolUses bool // Track if any tool uses were emitted
+ var hasTruncatedTools bool // Track if any tool uses were truncated
+ var upstreamStopReason string // Track stop_reason from upstream events
+
+ // Tool use state tracking for input buffering and deduplication
+ processedIDs := make(map[string]bool)
+ var currentToolUse *kiroclaude.ToolUseState
+
+ // NOTE: Duplicate content filtering removed - it was causing legitimate repeated
+ // content (like consecutive newlines) to be incorrectly filtered out.
+ // The previous implementation compared lastContentEvent == contentDelta which
+ // is too aggressive for streaming scenarios.
+
+ // Streaming token calculation - accumulate content for real-time token counting
+ // Based on AIClient-2-API implementation
+ var accumulatedContent strings.Builder
+ accumulatedContent.Grow(4096) // Pre-allocate 4KB capacity to reduce reallocations
+
+ // Real-time usage estimation state
+ // These track when to send periodic usage updates during streaming
+ var lastUsageUpdateLen int // Last accumulated content length when usage was sent
+ var lastUsageUpdateTime = time.Now() // Last time usage update was sent
+ var lastReportedOutputTokens int64 // Last reported output token count
+
+ // Upstream usage tracking - Kiro API returns credit usage and context percentage
+ var upstreamCreditUsage float64 // Credit usage from upstream (e.g., 1.458)
+ var upstreamContextPercentage float64 // Context usage percentage from upstream (e.g., 78.56)
+ var hasUpstreamUsage bool // Whether we received usage from upstream
+
+ // Translator param for maintaining tool call state across streaming events
+ // IMPORTANT: This must persist across all TranslateStream calls
+ var translatorParam any
+
+ // Thinking mode state tracking - tag-based parsing for tags in content
+ inThinkBlock := false // Whether we're currently inside a block
+ isThinkingBlockOpen := false // Track if thinking content block SSE event is open
+ thinkingBlockIndex := -1 // Index of the thinking content block
+ var accumulatedThinkingContent strings.Builder // Accumulate thinking content for token counting
+
+ // Buffer for handling partial tag matches at chunk boundaries
+ var pendingContent strings.Builder // Buffer content that might be part of a tag
+
+ // Pre-calculate input tokens from request if possible
+ // Kiro uses Claude format, so try Claude format first, then OpenAI format, then fallback
+ if enc, err := getTokenizer(model); err == nil {
+ var inputTokens int64
+ var countMethod string
+
+ // Try Claude format first (Kiro uses Claude API format)
+ if inp, err := countClaudeChatTokens(enc, claudeBody); err == nil && inp > 0 {
+ inputTokens = inp
+ countMethod = "claude"
+ } else if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 {
+ // Fallback to OpenAI format (for OpenAI-compatible requests)
+ inputTokens = inp
+ countMethod = "openai"
+ } else {
+ // Final fallback: estimate from raw request size (roughly 4 chars per token)
+ inputTokens = int64(len(claudeBody) / 4)
+ if inputTokens == 0 && len(claudeBody) > 0 {
+ inputTokens = 1
+ }
+ countMethod = "estimate"
+ }
+
+ totalUsage.InputTokens = inputTokens
+ log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (method: %s, claude body: %d bytes, original req: %d bytes)",
+ totalUsage.InputTokens, countMethod, len(claudeBody), len(originalReq))
+ }
+
+ contentBlockIndex := -1
+ messageStartSent := false
+ isTextBlockOpen := false
+ var outputLen int
+
+ // Ensure usage is published even on early return
+ defer func() {
+ reporter.publish(ctx, totalUsage)
+ }()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+
+ msg, eventErr := e.readEventStreamMessage(reader)
+ if eventErr != nil {
+ // Log the error
+ log.Errorf("kiro: streamToChannel error: %v", eventErr)
+
+ // Send error to channel for client notification
+ out <- cliproxyexecutor.StreamChunk{Err: eventErr}
+ return
+ }
+ if msg == nil {
+ // Normal end of stream (EOF)
+ // Flush any incomplete tool use before ending stream
+ if currentToolUse != nil && !processedIDs[currentToolUse.ToolUseID] {
+ log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID)
+ fullInput := currentToolUse.InputBuffer.String()
+ repairedJSON := kiroclaude.RepairJSON(fullInput)
+ var finalInput map[string]interface{}
+ if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil {
+ log.Warnf("kiro: failed to parse incomplete tool input at EOF: %v", err)
+ finalInput = make(map[string]interface{})
+ }
+
+ processedIDs[currentToolUse.ToolUseID] = true
+ contentBlockIndex++
+
+ // Send tool_use content block
+ blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+
+ // Send tool input as delta
+ inputBytes, _ := json.Marshal(finalInput)
+ inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex)
+ sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+
+ // Close block
+ blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
+ sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+
+ hasToolUses = true
+ currentToolUse = nil
+ }
+
+ // DISABLED: Tag-based pending character flushing
+ // This code block was used for tag-based thinking detection which has been
+ // replaced by reasoningContentEvent handling. No pending tag chars to flush.
+ // Original code preserved in git history.
+ break
+ }
+
+ eventType := msg.EventType
+ payload := msg.Payload
+ if len(payload) == 0 {
+ continue
+ }
+ appendAPIResponseChunk(ctx, e.cfg, payload)
+
+ var event map[string]interface{}
+ if err := json.Unmarshal(payload, &event); err != nil {
+ log.Warnf("kiro: failed to unmarshal event payload: %v, raw: %s", err, string(payload))
+ continue
+ }
+
+ // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200)
+ // These can appear as top-level fields or nested within the event
+ if errType, hasErrType := event["_type"].(string); hasErrType {
+ // AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."}
+ errMsg := ""
+ if msg, ok := event["message"].(string); ok {
+ errMsg = msg
+ }
+ log.Errorf("kiro: received AWS error in stream: type=%s, message=%s", errType, errMsg)
+ out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s - %s", errType, errMsg)}
+ return
+ }
+ if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") {
+ // Generic error event
+ errMsg := ""
+ if msg, ok := event["message"].(string); ok {
+ errMsg = msg
+ } else if errObj, ok := event["error"].(map[string]interface{}); ok {
+ if msg, ok := errObj["message"].(string); ok {
+ errMsg = msg
+ }
+ }
+ log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg)
+ out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s", errMsg)}
+ return
+ }
+
+ // Extract stop_reason from various event formats (streaming)
+ // Kiro/Amazon Q API may include stop_reason in different locations
+ if sr := kirocommon.GetString(event, "stop_reason"); sr != "" {
+ upstreamStopReason = sr
+ log.Debugf("kiro: streamToChannel found stop_reason (top-level): %s", upstreamStopReason)
+ }
+ if sr := kirocommon.GetString(event, "stopReason"); sr != "" {
+ upstreamStopReason = sr
+ log.Debugf("kiro: streamToChannel found stopReason (top-level): %s", upstreamStopReason)
+ }
+
+ // Send message_start on first event
+ if !messageStartSent {
+ msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ messageStartSent = true
+ }
+
+ switch eventType {
+ case "followupPromptEvent":
+ // Filter out followupPrompt events - these are UI suggestions, not content
+ log.Debugf("kiro: streamToChannel ignoring followupPrompt event")
+ continue
+
+ case "messageStopEvent", "message_stop":
+ // Handle message stop events which may contain stop_reason
+ if sr := kirocommon.GetString(event, "stop_reason"); sr != "" {
+ upstreamStopReason = sr
+ log.Debugf("kiro: streamToChannel found stop_reason in messageStopEvent: %s", upstreamStopReason)
+ }
+ if sr := kirocommon.GetString(event, "stopReason"); sr != "" {
+ upstreamStopReason = sr
+ log.Debugf("kiro: streamToChannel found stopReason in messageStopEvent: %s", upstreamStopReason)
+ }
+
+ case "meteringEvent":
+ // Handle metering events from Kiro API (usage billing information)
+ // Official format: { unit: string, unitPlural: string, usage: number }
+ if metering, ok := event["meteringEvent"].(map[string]interface{}); ok {
+ unit := ""
+ if u, ok := metering["unit"].(string); ok {
+ unit = u
+ }
+ usageVal := 0.0
+ if u, ok := metering["usage"].(float64); ok {
+ usageVal = u
+ }
+ upstreamCreditUsage = usageVal
+ hasUpstreamUsage = true
+ log.Infof("kiro: streamToChannel received meteringEvent: usage=%.4f %s", usageVal, unit)
+ } else {
+ // Try direct fields (event is meteringEvent itself)
+ if unit, ok := event["unit"].(string); ok {
+ if usage, ok := event["usage"].(float64); ok {
+ upstreamCreditUsage = usage
+ hasUpstreamUsage = true
+ log.Infof("kiro: streamToChannel received meteringEvent (direct): usage=%.4f %s", usage, unit)
+ }
+ }
+ }
+
+ case "contextUsageEvent":
+ // Handle context usage events from Kiro API
+ // Format: {"contextUsageEvent": {"contextUsagePercentage": 0.53}}
+ if ctxUsage, ok := event["contextUsageEvent"].(map[string]interface{}); ok {
+ if ctxPct, ok := ctxUsage["contextUsagePercentage"].(float64); ok {
+ upstreamContextPercentage = ctxPct
+ log.Debugf("kiro: streamToChannel received contextUsageEvent: %.2f%%", ctxPct*100)
+ }
+ } else {
+ // Try direct field (fallback)
+ if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
+ upstreamContextPercentage = ctxPct
+ log.Debugf("kiro: streamToChannel received contextUsagePercentage (direct): %.2f%%", ctxPct*100)
+ }
+ }
+
+ case "error", "exception", "internalServerException":
+ // Handle error events from Kiro API stream
+ errMsg := ""
+ errType := eventType
+
+ // Try to extract error message from various formats
+ if msg, ok := event["message"].(string); ok {
+ errMsg = msg
+ } else if errObj, ok := event[eventType].(map[string]interface{}); ok {
+ if msg, ok := errObj["message"].(string); ok {
+ errMsg = msg
+ }
+ if t, ok := errObj["type"].(string); ok {
+ errType = t
+ }
+ } else if errObj, ok := event["error"].(map[string]interface{}); ok {
+ if msg, ok := errObj["message"].(string); ok {
+ errMsg = msg
+ }
+ }
+
+ log.Errorf("kiro: streamToChannel received error event: type=%s, message=%s", errType, errMsg)
+
+ // Send error to the stream and exit
+ if errMsg != "" {
+ out <- cliproxyexecutor.StreamChunk{
+ Err: fmt.Errorf("kiro API error (%s): %s", errType, errMsg),
+ }
+ return
+ }
+
+ case "invalidStateEvent":
+ // Handle invalid state events - log and continue (non-fatal)
+ errMsg := ""
+ if msg, ok := event["message"].(string); ok {
+ errMsg = msg
+ } else if stateEvent, ok := event["invalidStateEvent"].(map[string]interface{}); ok {
+ if msg, ok := stateEvent["message"].(string); ok {
+ errMsg = msg
+ }
+ }
+ log.Warnf("kiro: streamToChannel received invalidStateEvent: %s, continuing", errMsg)
+ continue
+
+ default:
+ // Check for upstream usage events from Kiro API
+ // Format: {"unit":"credit","unitPlural":"credits","usage":1.458}
+ if unit, ok := event["unit"].(string); ok && unit == "credit" {
+ if usage, ok := event["usage"].(float64); ok {
+ upstreamCreditUsage = usage
+ hasUpstreamUsage = true
+ log.Debugf("kiro: received upstream credit usage: %.4f", upstreamCreditUsage)
+ }
+ }
+ // Format: {"contextUsagePercentage":78.56}
+ if ctxPct, ok := event["contextUsagePercentage"].(float64); ok {
+ upstreamContextPercentage = ctxPct
+ log.Debugf("kiro: received upstream context usage: %.2f%%", upstreamContextPercentage)
+ }
+
+ // Check for token counts in unknown events
+ if inputTokens, ok := event["inputTokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ hasUpstreamUsage = true
+ log.Debugf("kiro: streamToChannel found inputTokens in event %s: %d", eventType, totalUsage.InputTokens)
+ }
+ if outputTokens, ok := event["outputTokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ hasUpstreamUsage = true
+ log.Debugf("kiro: streamToChannel found outputTokens in event %s: %d", eventType, totalUsage.OutputTokens)
+ }
+ if totalTokens, ok := event["totalTokens"].(float64); ok {
+ totalUsage.TotalTokens = int64(totalTokens)
+ log.Debugf("kiro: streamToChannel found totalTokens in event %s: %d", eventType, totalUsage.TotalTokens)
+ }
+
+ // Check for usage object in unknown events (OpenAI/Claude format)
+ if usageObj, ok := event["usage"].(map[string]interface{}); ok {
+ if inputTokens, ok := usageObj["input_tokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ hasUpstreamUsage = true
+ } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ hasUpstreamUsage = true
+ }
+ if outputTokens, ok := usageObj["output_tokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ hasUpstreamUsage = true
+ } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ hasUpstreamUsage = true
+ }
+ if totalTokens, ok := usageObj["total_tokens"].(float64); ok {
+ totalUsage.TotalTokens = int64(totalTokens)
+ }
+ log.Debugf("kiro: streamToChannel found usage object in event %s: input=%d, output=%d, total=%d",
+ eventType, totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens)
+ }
+
+ // Log unknown event types for debugging (to discover new event formats)
+ if eventType != "" {
+ log.Debugf("kiro: streamToChannel unknown event type: %s, payload: %s", eventType, string(payload))
+ }
+
+ case "assistantResponseEvent":
+ var contentDelta string
+ var toolUses []map[string]interface{}
+
+ if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok {
+ if c, ok := assistantResp["content"].(string); ok {
+ contentDelta = c
+ }
+ // Extract stop_reason from assistantResponseEvent
+ if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" {
+ upstreamStopReason = sr
+ log.Debugf("kiro: streamToChannel found stop_reason in assistantResponseEvent: %s", upstreamStopReason)
+ }
+ if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" {
+ upstreamStopReason = sr
+ log.Debugf("kiro: streamToChannel found stopReason in assistantResponseEvent: %s", upstreamStopReason)
+ }
+ // Extract tool uses from response
+ if tus, ok := assistantResp["toolUses"].([]interface{}); ok {
+ for _, tuRaw := range tus {
+ if tu, ok := tuRaw.(map[string]interface{}); ok {
+ toolUses = append(toolUses, tu)
+ }
+ }
+ }
+ }
+ if contentDelta == "" {
+ if c, ok := event["content"].(string); ok {
+ contentDelta = c
+ }
+ }
+ // Direct tool uses
+ if tus, ok := event["toolUses"].([]interface{}); ok {
+ for _, tuRaw := range tus {
+ if tu, ok := tuRaw.(map[string]interface{}); ok {
+ toolUses = append(toolUses, tu)
+ }
+ }
+ }
+
+ // Handle text content with thinking mode support
+ if contentDelta != "" {
+ // NOTE: Duplicate content filtering was removed because it incorrectly
+ // filtered out legitimate repeated content (like consecutive newlines "\n\n").
+ // Streaming naturally can have identical chunks that are valid content.
+
+ outputLen += len(contentDelta)
+ // Accumulate content for streaming token calculation
+ accumulatedContent.WriteString(contentDelta)
+
+ // Real-time usage estimation: Check if we should send a usage update
+ // This helps clients track context usage during long thinking sessions
+ shouldSendUsageUpdate := false
+ if accumulatedContent.Len()-lastUsageUpdateLen >= usageUpdateCharThreshold {
+ shouldSendUsageUpdate = true
+ } else if time.Since(lastUsageUpdateTime) >= usageUpdateTimeInterval && accumulatedContent.Len() > lastUsageUpdateLen {
+ shouldSendUsageUpdate = true
+ }
+
+ if shouldSendUsageUpdate {
+ // Calculate current output tokens using tiktoken
+ var currentOutputTokens int64
+ if enc, encErr := getTokenizer(model); encErr == nil {
+ if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil {
+ currentOutputTokens = int64(tokenCount)
+ }
+ }
+ // Fallback to character estimation if tiktoken fails
+ if currentOutputTokens == 0 {
+ currentOutputTokens = int64(accumulatedContent.Len() / 4)
+ if currentOutputTokens == 0 {
+ currentOutputTokens = 1
+ }
+ }
+
+ // Only send update if token count has changed significantly (at least 10 tokens)
+ if currentOutputTokens > lastReportedOutputTokens+10 {
+ // Send ping event with usage information
+ // This is a non-blocking update that clients can optionally process
+ pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+
+ lastReportedOutputTokens = currentOutputTokens
+ log.Debugf("kiro: sent real-time usage update - input: %d, output: %d (accumulated: %d chars)",
+ totalUsage.InputTokens, currentOutputTokens, accumulatedContent.Len())
+ }
+
+ lastUsageUpdateLen = accumulatedContent.Len()
+ lastUsageUpdateTime = time.Now()
+ }
+
+ // TAG-BASED THINKING PARSING: Parse tags from content
+ // Combine pending content with new content for processing
+ pendingContent.WriteString(contentDelta)
+ processContent := pendingContent.String()
+ pendingContent.Reset()
+
+ // Process content looking for thinking tags
+ for len(processContent) > 0 {
+ if inThinkBlock {
+ // We're inside a thinking block, look for
+ endIdx := strings.Index(processContent, kirocommon.ThinkingEndTag)
+ if endIdx >= 0 {
+ // Found end tag - emit thinking content before the tag
+ thinkingText := processContent[:endIdx]
+ if thinkingText != "" {
+ // Ensure thinking block is open
+ if !isThinkingBlockOpen {
+ contentBlockIndex++
+ thinkingBlockIndex = contentBlockIndex
+ isThinkingBlockOpen = true
+ blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+ // Send thinking delta
+ thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ accumulatedThinkingContent.WriteString(thinkingText)
+ }
+ // Close thinking block
+ if isThinkingBlockOpen {
+ blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ isThinkingBlockOpen = false
+ }
+ inThinkBlock = false
+ processContent = processContent[endIdx+len(kirocommon.ThinkingEndTag):]
+ log.Debugf("kiro: closed thinking block, remaining content: %d chars", len(processContent))
+ } else {
+ // No end tag found - check for partial match at end
+ partialMatch := false
+ for i := 1; i < len(kirocommon.ThinkingEndTag) && i <= len(processContent); i++ {
+ if strings.HasSuffix(processContent, kirocommon.ThinkingEndTag[:i]) {
+ // Possible partial tag at end, buffer it
+ pendingContent.WriteString(processContent[len(processContent)-i:])
+ processContent = processContent[:len(processContent)-i]
+ partialMatch = true
+ break
+ }
+ }
+ if !partialMatch || len(processContent) > 0 {
+ // Emit all as thinking content
+ if processContent != "" {
+ if !isThinkingBlockOpen {
+ contentBlockIndex++
+ thinkingBlockIndex = contentBlockIndex
+ isThinkingBlockOpen = true
+ blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+ thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(processContent, thinkingBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ accumulatedThinkingContent.WriteString(processContent)
+ }
+ }
+ processContent = ""
+ }
+ } else {
+ // Not in thinking block, look for
+ startIdx := strings.Index(processContent, kirocommon.ThinkingStartTag)
+ if startIdx >= 0 {
+ // Found start tag - emit text content before the tag
+ textBefore := processContent[:startIdx]
+ if textBefore != "" {
+ // Close thinking block if open
+ if isThinkingBlockOpen {
+ blockStop := kiroclaude.BuildClaudeThinkingBlockStopEvent(thinkingBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ isThinkingBlockOpen = false
+ }
+ // Ensure text block is open
+ if !isTextBlockOpen {
+ contentBlockIndex++
+ isTextBlockOpen = true
+ blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+ // Send text delta
+ claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+ // Close text block before entering thinking
+ if isTextBlockOpen {
+ blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ isTextBlockOpen = false
+ }
+ inThinkBlock = true
+ processContent = processContent[startIdx+len(kirocommon.ThinkingStartTag):]
+ log.Debugf("kiro: entered thinking block")
+ } else {
+ // No start tag found - check for partial match at end
+ partialMatch := false
+ for i := 1; i < len(kirocommon.ThinkingStartTag) && i <= len(processContent); i++ {
+ if strings.HasSuffix(processContent, kirocommon.ThinkingStartTag[:i]) {
+ // Possible partial tag at end, buffer it
+ pendingContent.WriteString(processContent[len(processContent)-i:])
+ processContent = processContent[:len(processContent)-i]
+ partialMatch = true
+ break
+ }
+ }
+ if !partialMatch || len(processContent) > 0 {
+ // Emit all as text content
+ if processContent != "" {
+ if !isTextBlockOpen {
+ contentBlockIndex++
+ isTextBlockOpen = true
+ blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "")
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+ claudeEvent := kiroclaude.BuildClaudeStreamEvent(processContent, contentBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+ }
+ processContent = ""
+ }
+ }
+ }
+ }
+
+ // Handle tool uses in response (with deduplication)
+ for _, tu := range toolUses {
+ toolUseID := kirocommon.GetString(tu, "toolUseId")
+ toolName := kirocommon.GetString(tu, "name")
+
+ // Check for duplicate
+ if processedIDs[toolUseID] {
+ log.Debugf("kiro: skipping duplicate tool use in stream: %s", toolUseID)
+ continue
+ }
+ processedIDs[toolUseID] = true
+
+ hasToolUses = true
+ // Close text block if open before starting tool_use block
+ if isTextBlockOpen && contentBlockIndex >= 0 {
+ blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ isTextBlockOpen = false
+ }
+
+ // Emit tool_use content block
+ contentBlockIndex++
+
+ blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+
+ // Send input_json_delta with the tool input
+ if input, ok := tu["input"].(map[string]interface{}); ok {
+ inputJSON, err := json.Marshal(input)
+ if err != nil {
+ log.Debugf("kiro: failed to marshal tool input: %v", err)
+ // Don't continue - still need to close the block
+ } else {
+ inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
+ sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+ }
+
+ // Close tool_use block (always close even if input marshal failed)
+ blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
+ sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+
+ case "reasoningContentEvent":
+ // Handle official reasoningContentEvent from Kiro API
+ // This replaces tag-based thinking detection with the proper event type
+ // Official format: { text: string, signature?: string, redactedContent?: base64 }
+ var thinkingText string
+ var signature string
+
+ if re, ok := event["reasoningContentEvent"].(map[string]interface{}); ok {
+ if text, ok := re["text"].(string); ok {
+ thinkingText = text
+ }
+ if sig, ok := re["signature"].(string); ok {
+ signature = sig
+ if len(sig) > 20 {
+ log.Debugf("kiro: reasoningContentEvent has signature: %s...", sig[:20])
+ } else {
+ log.Debugf("kiro: reasoningContentEvent has signature: %s", sig)
+ }
+ }
+ } else {
+ // Try direct fields
+ if text, ok := event["text"].(string); ok {
+ thinkingText = text
+ }
+ if sig, ok := event["signature"].(string); ok {
+ signature = sig
+ }
+ }
+
+ if thinkingText != "" {
+ // Close text block if open before starting thinking block
+ if isTextBlockOpen && contentBlockIndex >= 0 {
+ blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ isTextBlockOpen = false
+ }
+
+ // Start thinking block if not already open
+ if !isThinkingBlockOpen {
+ contentBlockIndex++
+ thinkingBlockIndex = contentBlockIndex
+ isThinkingBlockOpen = true
+ blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "")
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+
+ // Send thinking content
+ thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkingText, thinkingBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+
+ // Accumulate for token counting
+ accumulatedThinkingContent.WriteString(thinkingText)
+ log.Debugf("kiro: received reasoningContentEvent, text length: %d, has signature: %v", len(thinkingText), signature != "")
+ }
+
+ // Note: We don't close the thinking block here - it will be closed when we see
+ // the next assistantResponseEvent or at the end of the stream
+ _ = signature // Signature can be used for verification if needed
+
+ case "toolUseEvent":
+ // Handle dedicated tool use events with input buffering
+ completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs)
+ currentToolUse = newState
+
+ // Emit completed tool uses
+ for _, tu := range completedToolUses {
+ // Check if this tool was truncated - emit with SOFT_LIMIT_REACHED marker
+ if tu.IsTruncated {
+ hasTruncatedTools = true
+ log.Infof("kiro: streamToChannel emitting truncated tool with SOFT_LIMIT_REACHED: %s (ID: %s)", tu.Name, tu.ToolUseID)
+
+ // Close text block if open
+ if isTextBlockOpen && contentBlockIndex >= 0 {
+ blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ isTextBlockOpen = false
+ }
+
+ contentBlockIndex++
+
+ // Emit tool_use with SOFT_LIMIT_REACHED marker input
+ blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+
+ // Build SOFT_LIMIT_REACHED marker input
+ markerInput := map[string]interface{}{
+ "_status": "SOFT_LIMIT_REACHED",
+ "_message": "Tool output was truncated. Split content into smaller chunks (max 300 lines). Due to potential model hallucination, you MUST re-fetch the current working directory and generate the correct file_path.",
+ }
+
+ markerJSON, _ := json.Marshal(markerInput)
+ inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(markerJSON), contentBlockIndex)
+ sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+
+ // Close tool_use block
+ blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
+ sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+
+ hasToolUses = true // Keep this so stop_reason = tool_use
+ continue
+ }
+
+ hasToolUses = true
+
+ // Close text block if open
+ if isTextBlockOpen && contentBlockIndex >= 0 {
+ blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ isTextBlockOpen = false
+ }
+
+ contentBlockIndex++
+
+ blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+
+ if tu.Input != nil {
+ inputJSON, err := json.Marshal(tu.Input)
+ if err != nil {
+ log.Debugf("kiro: failed to marshal tool input in toolUseEvent: %v", err)
+ } else {
+ inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex)
+ sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+ }
+
+ blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
+ sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+
+ case "supplementaryWebLinksEvent":
+ if inputTokens, ok := event["inputTokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ }
+ if outputTokens, ok := event["outputTokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ }
+
+ case "messageMetadataEvent", "metadataEvent":
+ // Handle message metadata events which contain token counts
+ // Official format: { tokenUsage: { outputTokens, totalTokens, uncachedInputTokens, cacheReadInputTokens, cacheWriteInputTokens, contextUsagePercentage } }
+ var metadata map[string]interface{}
+ if m, ok := event["messageMetadataEvent"].(map[string]interface{}); ok {
+ metadata = m
+ } else if m, ok := event["metadataEvent"].(map[string]interface{}); ok {
+ metadata = m
+ } else {
+ metadata = event // event itself might be the metadata
+ }
+
+ // Check for nested tokenUsage object (official format)
+ if tokenUsage, ok := metadata["tokenUsage"].(map[string]interface{}); ok {
+ // outputTokens - precise output token count
+ if outputTokens, ok := tokenUsage["outputTokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ hasUpstreamUsage = true
+ log.Infof("kiro: streamToChannel found precise outputTokens in tokenUsage: %d", totalUsage.OutputTokens)
+ }
+ // totalTokens - precise total token count
+ if totalTokens, ok := tokenUsage["totalTokens"].(float64); ok {
+ totalUsage.TotalTokens = int64(totalTokens)
+ log.Infof("kiro: streamToChannel found precise totalTokens in tokenUsage: %d", totalUsage.TotalTokens)
+ }
+ // uncachedInputTokens - input tokens not from cache
+ if uncachedInputTokens, ok := tokenUsage["uncachedInputTokens"].(float64); ok {
+ totalUsage.InputTokens = int64(uncachedInputTokens)
+ hasUpstreamUsage = true
+ log.Infof("kiro: streamToChannel found uncachedInputTokens in tokenUsage: %d", totalUsage.InputTokens)
+ }
+ // cacheReadInputTokens - tokens read from cache
+ if cacheReadTokens, ok := tokenUsage["cacheReadInputTokens"].(float64); ok {
+ // Add to input tokens if we have uncached tokens, otherwise use as input
+ if totalUsage.InputTokens > 0 {
+ totalUsage.InputTokens += int64(cacheReadTokens)
+ } else {
+ totalUsage.InputTokens = int64(cacheReadTokens)
+ }
+ hasUpstreamUsage = true
+ log.Debugf("kiro: streamToChannel found cacheReadInputTokens in tokenUsage: %d", int64(cacheReadTokens))
+ }
+ // contextUsagePercentage - can be used as fallback for input token estimation
+ if ctxPct, ok := tokenUsage["contextUsagePercentage"].(float64); ok {
+ upstreamContextPercentage = ctxPct
+ log.Debugf("kiro: streamToChannel found contextUsagePercentage in tokenUsage: %.2f%%", ctxPct)
+ }
+ }
+
+ // Fallback: check for direct fields in metadata (legacy format)
+ if totalUsage.InputTokens == 0 {
+ if inputTokens, ok := metadata["inputTokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ hasUpstreamUsage = true
+ log.Debugf("kiro: streamToChannel found inputTokens in messageMetadataEvent: %d", totalUsage.InputTokens)
+ }
+ }
+ if totalUsage.OutputTokens == 0 {
+ if outputTokens, ok := metadata["outputTokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ hasUpstreamUsage = true
+ log.Debugf("kiro: streamToChannel found outputTokens in messageMetadataEvent: %d", totalUsage.OutputTokens)
+ }
+ }
+ if totalUsage.TotalTokens == 0 {
+ if totalTokens, ok := metadata["totalTokens"].(float64); ok {
+ totalUsage.TotalTokens = int64(totalTokens)
+ log.Debugf("kiro: streamToChannel found totalTokens in messageMetadataEvent: %d", totalUsage.TotalTokens)
+ }
+ }
+
+ case "usageEvent", "usage":
+ // Handle dedicated usage events
+ if inputTokens, ok := event["inputTokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ log.Debugf("kiro: streamToChannel found inputTokens in usageEvent: %d", totalUsage.InputTokens)
+ }
+ if outputTokens, ok := event["outputTokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ log.Debugf("kiro: streamToChannel found outputTokens in usageEvent: %d", totalUsage.OutputTokens)
+ }
+ if totalTokens, ok := event["totalTokens"].(float64); ok {
+ totalUsage.TotalTokens = int64(totalTokens)
+ log.Debugf("kiro: streamToChannel found totalTokens in usageEvent: %d", totalUsage.TotalTokens)
+ }
+ // Also check nested usage object
+ if usageObj, ok := event["usage"].(map[string]interface{}); ok {
+ if inputTokens, ok := usageObj["input_tokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ }
+ if outputTokens, ok := usageObj["output_tokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ }
+ if totalTokens, ok := usageObj["total_tokens"].(float64); ok {
+ totalUsage.TotalTokens = int64(totalTokens)
+ }
+ log.Debugf("kiro: streamToChannel found usage object: input=%d, output=%d, total=%d",
+ totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens)
+ }
+
+ case "metricsEvent":
+ // Handle metrics events which may contain usage data
+ if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok {
+ if inputTokens, ok := metrics["inputTokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ }
+ if outputTokens, ok := metrics["outputTokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ }
+ log.Debugf("kiro: streamToChannel found metricsEvent: input=%d, output=%d",
+ totalUsage.InputTokens, totalUsage.OutputTokens)
+ }
+ }
+
+ // Check nested usage event
+ if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok {
+ if inputTokens, ok := usageEvent["inputTokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ }
+ if outputTokens, ok := usageEvent["outputTokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ }
+ }
+
+ // Check for direct token fields in any event (fallback)
+ if totalUsage.InputTokens == 0 {
+ if inputTokens, ok := event["inputTokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ log.Debugf("kiro: streamToChannel found direct inputTokens: %d", totalUsage.InputTokens)
+ }
+ }
+ if totalUsage.OutputTokens == 0 {
+ if outputTokens, ok := event["outputTokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ log.Debugf("kiro: streamToChannel found direct outputTokens: %d", totalUsage.OutputTokens)
+ }
+ }
+
+ // Check for usage object in any event (OpenAI format)
+ if totalUsage.InputTokens == 0 || totalUsage.OutputTokens == 0 {
+ if usageObj, ok := event["usage"].(map[string]interface{}); ok {
+ if totalUsage.InputTokens == 0 {
+ if inputTokens, ok := usageObj["input_tokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok {
+ totalUsage.InputTokens = int64(inputTokens)
+ }
+ }
+ if totalUsage.OutputTokens == 0 {
+ if outputTokens, ok := usageObj["output_tokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok {
+ totalUsage.OutputTokens = int64(outputTokens)
+ }
+ }
+ if totalUsage.TotalTokens == 0 {
+ if totalTokens, ok := usageObj["total_tokens"].(float64); ok {
+ totalUsage.TotalTokens = int64(totalTokens)
+ }
+ }
+ log.Debugf("kiro: streamToChannel found usage object (fallback): input=%d, output=%d, total=%d",
+ totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens)
+ }
+ }
+ }
+
+ // Close content block if open
+ if isTextBlockOpen && contentBlockIndex >= 0 {
+ blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ }
+
+ // Streaming token calculation - calculate output tokens from accumulated content
+ // Only use local estimation if server didn't provide usage (server-side usage takes priority)
+ if totalUsage.OutputTokens == 0 && accumulatedContent.Len() > 0 {
+ // Try to use tiktoken for accurate counting
+ if enc, err := getTokenizer(model); err == nil {
+ if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil {
+ totalUsage.OutputTokens = int64(tokenCount)
+ log.Debugf("kiro: streamToChannel calculated output tokens using tiktoken: %d", totalUsage.OutputTokens)
+ } else {
+ // Fallback on count error: estimate from character count
+ totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4)
+ if totalUsage.OutputTokens == 0 {
+ totalUsage.OutputTokens = 1
+ }
+ log.Debugf("kiro: streamToChannel tiktoken count failed, estimated from chars: %d", totalUsage.OutputTokens)
+ }
+ } else {
+ // Fallback: estimate from character count (roughly 4 chars per token)
+ totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4)
+ if totalUsage.OutputTokens == 0 {
+ totalUsage.OutputTokens = 1
+ }
+ log.Debugf("kiro: streamToChannel estimated output tokens from chars: %d (content len: %d)", totalUsage.OutputTokens, accumulatedContent.Len())
+ }
+ } else if totalUsage.OutputTokens == 0 && outputLen > 0 {
+ // Legacy fallback using outputLen
+ totalUsage.OutputTokens = int64(outputLen / 4)
+ if totalUsage.OutputTokens == 0 {
+ totalUsage.OutputTokens = 1
+ }
+ }
+
+ // Use contextUsagePercentage to calculate more accurate input tokens
+ // Kiro model has 200k max context, contextUsagePercentage represents the percentage used
+ // Formula: input_tokens = contextUsagePercentage * 200000 / 100
+ // Note: The effective input context is ~170k (200k - 30k reserved for output)
+ if upstreamContextPercentage > 0 {
+ // Calculate input tokens from context percentage
+ // Using 200k as the base since that's what Kiro reports against
+ calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100)
+
+ // Only use calculated value if it's significantly different from local estimate
+ // This provides more accurate token counts based on upstream data
+ if calculatedInputTokens > 0 {
+ localEstimate := totalUsage.InputTokens
+ totalUsage.InputTokens = calculatedInputTokens
+ log.Debugf("kiro: using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)",
+ upstreamContextPercentage, calculatedInputTokens, localEstimate)
+ }
+ }
+
+ totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens
+
+ // Log upstream usage information if received
+ if hasUpstreamUsage {
+ log.Debugf("kiro: upstream usage - credits: %.4f, context: %.2f%%, final tokens - input: %d, output: %d, total: %d",
+ upstreamCreditUsage, upstreamContextPercentage,
+ totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens)
+ }
+
+ // Determine stop reason: prefer upstream, then detect tool_use, default to end_turn
+ // SOFT_LIMIT_REACHED: Keep stop_reason = "tool_use" so Claude continues the loop
+ stopReason := upstreamStopReason
+ if hasTruncatedTools {
+ // Log that we're using SOFT_LIMIT_REACHED approach
+ log.Infof("kiro: streamToChannel using SOFT_LIMIT_REACHED - keeping stop_reason=tool_use for truncated tools")
+ }
+ if stopReason == "" {
+ if hasToolUses {
+ stopReason = "tool_use"
+ log.Debugf("kiro: streamToChannel using fallback stop_reason: tool_use")
+ } else {
+ stopReason = "end_turn"
+ log.Debugf("kiro: streamToChannel using fallback stop_reason: end_turn")
+ }
+ }
+
+ // Log warning if response was truncated due to max_tokens
+ if stopReason == "max_tokens" {
+ log.Warnf("kiro: response truncated due to max_tokens limit (streamToChannel)")
+ }
+
+ // Send message_delta event
+ msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage)
+ sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+
+ // Send message_stop event separately
+ msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent()
+ sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam)
+ for _, chunk := range sseData {
+ if chunk != "" {
+ out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")}
+ }
+ }
+ // reporter.publish is called via defer
+}
+
+// NOTE: Claude SSE event builders moved to internal/translator/kiro/claude/kiro_claude_stream.go
+// The executor now uses kiroclaude.BuildClaude*Event() functions instead
+
+// CountTokens counts tokens locally using tiktoken since Kiro API doesn't expose a token counting endpoint.
+// This provides approximate token counts for client requests.
+func (e *KiroExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
+ // Use tiktoken for local token counting
+ enc, err := getTokenizer(req.Model)
+ if err != nil {
+ log.Warnf("kiro: CountTokens failed to get tokenizer: %v, falling back to estimate", err)
+ // Fallback: estimate from payload size (roughly 4 chars per token)
+ estimatedTokens := len(req.Payload) / 4
+ if estimatedTokens == 0 && len(req.Payload) > 0 {
+ estimatedTokens = 1
+ }
+ return cliproxyexecutor.Response{
+ Payload: []byte(fmt.Sprintf(`{"count":%d}`, estimatedTokens)),
+ }, nil
+ }
+
+ // Try to count tokens from the request payload
+ var totalTokens int64
+
+ // Try OpenAI chat format first
+ if tokens, countErr := countOpenAIChatTokens(enc, req.Payload); countErr == nil && tokens > 0 {
+ totalTokens = tokens
+ log.Debugf("kiro: CountTokens counted %d tokens using OpenAI chat format", totalTokens)
+ } else {
+ // Fallback: count raw payload tokens
+ if tokenCount, countErr := enc.Count(string(req.Payload)); countErr == nil {
+ totalTokens = int64(tokenCount)
+ log.Debugf("kiro: CountTokens counted %d tokens from raw payload", totalTokens)
+ } else {
+ // Final fallback: estimate from payload size
+ totalTokens = int64(len(req.Payload) / 4)
+ if totalTokens == 0 && len(req.Payload) > 0 {
+ totalTokens = 1
+ }
+ log.Debugf("kiro: CountTokens estimated %d tokens from payload size", totalTokens)
+ }
+ }
+
+ return cliproxyexecutor.Response{
+ Payload: []byte(fmt.Sprintf(`{"count":%d}`, totalTokens)),
+ }, nil
+}
+
+// Refresh refreshes the Kiro OAuth token.
+// Supports both AWS Builder ID (SSO OIDC) and Google OAuth (social login).
+// Uses mutex to prevent race conditions when multiple concurrent requests try to refresh.
+func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
+ // Serialize token refresh operations to prevent race conditions
+ e.refreshMu.Lock()
+ defer e.refreshMu.Unlock()
+
+ var authID string
+ if auth != nil {
+ authID = auth.ID
+ } else {
+ authID = ""
+ }
+ log.Debugf("kiro executor: refresh called for auth %s", authID)
+ if auth == nil {
+ return nil, fmt.Errorf("kiro executor: auth is nil")
+ }
+
+ // Double-check: After acquiring lock, verify token still needs refresh
+ // Another goroutine may have already refreshed while we were waiting
+ // NOTE: This check has a design limitation - it reads from the auth object passed in,
+ // not from persistent storage. If another goroutine returns a new Auth object (via Clone),
+ // this check won't see those updates. The mutex still prevents truly concurrent refreshes,
+ // but queued goroutines may still attempt redundant refreshes. This is acceptable as
+ // the refresh operation is idempotent and the extra API calls are infrequent.
+ if auth.Metadata != nil {
+ if lastRefresh, ok := auth.Metadata["last_refresh"].(string); ok {
+ if refreshTime, err := time.Parse(time.RFC3339, lastRefresh); err == nil {
+ // If token was refreshed within the last 30 seconds, skip refresh
+ if time.Since(refreshTime) < 30*time.Second {
+ log.Debugf("kiro executor: token was recently refreshed by another goroutine, skipping")
+ return auth, nil
+ }
+ }
+ }
+ // Also check if expires_at is now in the future with sufficient buffer
+ if expiresAt, ok := auth.Metadata["expires_at"].(string); ok {
+ if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil {
+ // If token expires more than 20 minutes from now, it's still valid
+ if time.Until(expTime) > 20*time.Minute {
+ log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime))
+ // CRITICAL FIX: Set NextRefreshAfter to prevent frequent refresh checks
+ // Without this, shouldRefresh() will return true again in 30 seconds
+ updated := auth.Clone()
+ // Set next refresh to 20 minutes before expiry, or at least 30 seconds from now
+ nextRefresh := expTime.Add(-20 * time.Minute)
+ minNextRefresh := time.Now().Add(30 * time.Second)
+ if nextRefresh.Before(minNextRefresh) {
+ nextRefresh = minNextRefresh
+ }
+ updated.NextRefreshAfter = nextRefresh
+ log.Debugf("kiro executor: setting NextRefreshAfter to %v (in %v)", nextRefresh.Format(time.RFC3339), time.Until(nextRefresh))
+ return updated, nil
+ }
+ }
+ }
+ }
+
+ var refreshToken string
+ var clientID, clientSecret string
+ var authMethod string
+ var region, startURL string
+
+ if auth.Metadata != nil {
+ if rt, ok := auth.Metadata["refresh_token"].(string); ok {
+ refreshToken = rt
+ }
+ if cid, ok := auth.Metadata["client_id"].(string); ok {
+ clientID = cid
+ }
+ if cs, ok := auth.Metadata["client_secret"].(string); ok {
+ clientSecret = cs
+ }
+ if am, ok := auth.Metadata["auth_method"].(string); ok {
+ authMethod = am
+ }
+ if r, ok := auth.Metadata["region"].(string); ok {
+ region = r
+ }
+ if su, ok := auth.Metadata["start_url"].(string); ok {
+ startURL = su
+ }
+ }
+
+ if refreshToken == "" {
+ return nil, fmt.Errorf("kiro executor: refresh token not found")
+ }
+
+ var tokenData *kiroauth.KiroTokenData
+ var err error
+
+ ssoClient := kiroauth.NewSSOOIDCClient(e.cfg)
+
+ // Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
+ switch {
+ case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
+ // IDC refresh with region-specific endpoint
+ log.Debugf("kiro executor: using SSO OIDC refresh for IDC (region=%s)", region)
+ tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
+ case clientID != "" && clientSecret != "" && authMethod == "builder-id":
+ // Builder ID refresh with default endpoint
+ log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID")
+ tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
+ default:
+ // Fallback to Kiro's OAuth refresh endpoint (for social auth: Google/GitHub)
+ log.Debugf("kiro executor: using Kiro OAuth refresh endpoint")
+ oauth := kiroauth.NewKiroOAuth(e.cfg)
+ tokenData, err = oauth.RefreshToken(ctx, refreshToken)
+ }
+
+ if err != nil {
+ return nil, fmt.Errorf("kiro executor: token refresh failed: %w", err)
+ }
+
+ updated := auth.Clone()
+ now := time.Now()
+ updated.UpdatedAt = now
+ updated.LastRefreshedAt = now
+
+ if updated.Metadata == nil {
+ updated.Metadata = make(map[string]any)
+ }
+ updated.Metadata["access_token"] = tokenData.AccessToken
+ updated.Metadata["refresh_token"] = tokenData.RefreshToken
+ updated.Metadata["expires_at"] = tokenData.ExpiresAt
+ updated.Metadata["last_refresh"] = now.Format(time.RFC3339)
+ if tokenData.ProfileArn != "" {
+ updated.Metadata["profile_arn"] = tokenData.ProfileArn
+ }
+ if tokenData.AuthMethod != "" {
+ updated.Metadata["auth_method"] = tokenData.AuthMethod
+ }
+ if tokenData.Provider != "" {
+ updated.Metadata["provider"] = tokenData.Provider
+ }
+ // Preserve client credentials for future refreshes (AWS Builder ID)
+ if tokenData.ClientID != "" {
+ updated.Metadata["client_id"] = tokenData.ClientID
+ }
+ if tokenData.ClientSecret != "" {
+ updated.Metadata["client_secret"] = tokenData.ClientSecret
+ }
+ // Preserve region and start_url for IDC token refresh
+ if tokenData.Region != "" {
+ updated.Metadata["region"] = tokenData.Region
+ }
+ if tokenData.StartURL != "" {
+ updated.Metadata["start_url"] = tokenData.StartURL
+ }
+
+ if updated.Attributes == nil {
+ updated.Attributes = make(map[string]string)
+ }
+ updated.Attributes["access_token"] = tokenData.AccessToken
+ if tokenData.ProfileArn != "" {
+ updated.Attributes["profile_arn"] = tokenData.ProfileArn
+ }
+
+ // NextRefreshAfter is aligned with RefreshLead (20min)
+ if expiresAt, parseErr := time.Parse(time.RFC3339, tokenData.ExpiresAt); parseErr == nil {
+ updated.NextRefreshAfter = expiresAt.Add(-20 * time.Minute)
+ }
+
+ log.Infof("kiro executor: token refreshed successfully, expires at %s", tokenData.ExpiresAt)
+ return updated, nil
+}
+
+// persistRefreshedAuth persists a refreshed auth record to disk.
+// This ensures token refreshes from inline retry are saved to the auth file.
+func (e *KiroExecutor) persistRefreshedAuth(auth *cliproxyauth.Auth) error {
+ if auth == nil || auth.Metadata == nil {
+ return fmt.Errorf("kiro executor: cannot persist nil auth or metadata")
+ }
+
+ // Determine the file path from auth attributes or filename
+ var authPath string
+ if auth.Attributes != nil {
+ if p := strings.TrimSpace(auth.Attributes["path"]); p != "" {
+ authPath = p
+ }
+ }
+ if authPath == "" {
+ fileName := strings.TrimSpace(auth.FileName)
+ if fileName == "" {
+ return fmt.Errorf("kiro executor: auth has no file path or filename")
+ }
+ if filepath.IsAbs(fileName) {
+ authPath = fileName
+ } else if e.cfg != nil && e.cfg.AuthDir != "" {
+ authPath = filepath.Join(e.cfg.AuthDir, fileName)
+ } else {
+ return fmt.Errorf("kiro executor: cannot determine auth file path")
+ }
+ }
+
+ // Marshal metadata to JSON
+ raw, err := json.Marshal(auth.Metadata)
+ if err != nil {
+ return fmt.Errorf("kiro executor: marshal metadata failed: %w", err)
+ }
+
+ // Write to temp file first, then rename (atomic write)
+ tmp := authPath + ".tmp"
+ if err := os.WriteFile(tmp, raw, 0o600); err != nil {
+ return fmt.Errorf("kiro executor: write temp auth file failed: %w", err)
+ }
+ if err := os.Rename(tmp, authPath); err != nil {
+ return fmt.Errorf("kiro executor: rename auth file failed: %w", err)
+ }
+
+ log.Debugf("kiro executor: persisted refreshed auth to %s", authPath)
+ return nil
+}
+
+// reloadAuthFromFile 从文件重新加载 auth 数据(方案 B: Fallback 机制)
+// 当内存中的 token 已过期时,尝试从文件读取最新的 token
+// 这解决了后台刷新器已更新文件但内存中 Auth 对象尚未同步的时间差问题
+func (e *KiroExecutor) reloadAuthFromFile(auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
+ if auth == nil {
+ return nil, fmt.Errorf("kiro executor: cannot reload nil auth")
+ }
+
+ // 确定文件路径
+ var authPath string
+ if auth.Attributes != nil {
+ if p := strings.TrimSpace(auth.Attributes["path"]); p != "" {
+ authPath = p
+ }
+ }
+ if authPath == "" {
+ fileName := strings.TrimSpace(auth.FileName)
+ if fileName == "" {
+ return nil, fmt.Errorf("kiro executor: auth has no file path or filename for reload")
+ }
+ if filepath.IsAbs(fileName) {
+ authPath = fileName
+ } else if e.cfg != nil && e.cfg.AuthDir != "" {
+ authPath = filepath.Join(e.cfg.AuthDir, fileName)
+ } else {
+ return nil, fmt.Errorf("kiro executor: cannot determine auth file path for reload")
+ }
+ }
+
+ // 读取文件
+ raw, err := os.ReadFile(authPath)
+ if err != nil {
+ return nil, fmt.Errorf("kiro executor: failed to read auth file %s: %w", authPath, err)
+ }
+
+ // 解析 JSON
+ var metadata map[string]any
+ if err := json.Unmarshal(raw, &metadata); err != nil {
+ return nil, fmt.Errorf("kiro executor: failed to parse auth file %s: %w", authPath, err)
+ }
+
+ // 检查文件中的 token 是否比内存中的更新
+ fileExpiresAt, _ := metadata["expires_at"].(string)
+ fileAccessToken, _ := metadata["access_token"].(string)
+ memExpiresAt, _ := auth.Metadata["expires_at"].(string)
+ memAccessToken, _ := auth.Metadata["access_token"].(string)
+
+ // 文件中必须有有效的 access_token
+ if fileAccessToken == "" {
+ return nil, fmt.Errorf("kiro executor: auth file has no access_token field")
+ }
+
+ // 如果有 expires_at,检查是否过期
+ if fileExpiresAt != "" {
+ fileExpTime, parseErr := time.Parse(time.RFC3339, fileExpiresAt)
+ if parseErr == nil {
+ // 如果文件中的 token 也已过期,不使用它
+ if time.Now().After(fileExpTime) {
+ log.Debugf("kiro executor: file token also expired at %s, not using", fileExpiresAt)
+ return nil, fmt.Errorf("kiro executor: file token also expired")
+ }
+ }
+ }
+
+ // 判断文件中的 token 是否比内存中的更新
+ // 条件1: access_token 不同(说明已刷新)
+ // 条件2: expires_at 更新(说明已刷新)
+ isNewer := false
+
+ // 优先检查 access_token 是否变化
+ if fileAccessToken != memAccessToken {
+ isNewer = true
+ log.Debugf("kiro executor: file access_token differs from memory, using file token")
+ }
+
+ // 如果 access_token 相同,检查 expires_at
+ if !isNewer && fileExpiresAt != "" && memExpiresAt != "" {
+ fileExpTime, fileParseErr := time.Parse(time.RFC3339, fileExpiresAt)
+ memExpTime, memParseErr := time.Parse(time.RFC3339, memExpiresAt)
+ if fileParseErr == nil && memParseErr == nil && fileExpTime.After(memExpTime) {
+ isNewer = true
+ log.Debugf("kiro executor: file expires_at (%s) is newer than memory (%s)", fileExpiresAt, memExpiresAt)
+ }
+ }
+
+ // 如果文件中没有 expires_at 但 access_token 相同,无法判断是否更新
+ if !isNewer && fileExpiresAt == "" && fileAccessToken == memAccessToken {
+ return nil, fmt.Errorf("kiro executor: cannot determine if file token is newer (no expires_at, same access_token)")
+ }
+
+ if !isNewer {
+ log.Debugf("kiro executor: file token not newer than memory token")
+ return nil, fmt.Errorf("kiro executor: file token not newer")
+ }
+
+ // 创建更新后的 auth 对象
+ updated := auth.Clone()
+ updated.Metadata = metadata
+ updated.UpdatedAt = time.Now()
+
+ // 同步更新 Attributes
+ if updated.Attributes == nil {
+ updated.Attributes = make(map[string]string)
+ }
+ if accessToken, ok := metadata["access_token"].(string); ok {
+ updated.Attributes["access_token"] = accessToken
+ }
+ if profileArn, ok := metadata["profile_arn"].(string); ok {
+ updated.Attributes["profile_arn"] = profileArn
+ }
+
+ log.Infof("kiro executor: reloaded auth from file %s, new expires_at: %s", authPath, fileExpiresAt)
+ return updated, nil
+}
+
+// isTokenExpired checks if a JWT access token has expired.
+// Returns true if the token is expired or cannot be parsed.
+func (e *KiroExecutor) isTokenExpired(accessToken string) bool {
+ if accessToken == "" {
+ return true
+ }
+
+ // JWT tokens have 3 parts separated by dots
+ parts := strings.Split(accessToken, ".")
+ if len(parts) != 3 {
+ // Not a JWT token, assume not expired
+ return false
+ }
+
+ // Decode the payload (second part)
+ // JWT uses base64url encoding without padding (RawURLEncoding)
+ payload := parts[1]
+ decoded, err := base64.RawURLEncoding.DecodeString(payload)
+ if err != nil {
+ // Try with padding added as fallback
+ switch len(payload) % 4 {
+ case 2:
+ payload += "=="
+ case 3:
+ payload += "="
+ }
+ decoded, err = base64.URLEncoding.DecodeString(payload)
+ if err != nil {
+ log.Debugf("kiro: failed to decode JWT payload: %v", err)
+ return false
+ }
+ }
+
+ var claims struct {
+ Exp int64 `json:"exp"`
+ }
+ if err := json.Unmarshal(decoded, &claims); err != nil {
+ log.Debugf("kiro: failed to parse JWT claims: %v", err)
+ return false
+ }
+
+ if claims.Exp == 0 {
+ // No expiration claim, assume not expired
+ return false
+ }
+
+ expTime := time.Unix(claims.Exp, 0)
+ now := time.Now()
+
+ // Consider token expired if it expires within 1 minute (buffer for clock skew)
+ isExpired := now.After(expTime) || expTime.Sub(now) < time.Minute
+ if isExpired {
+ log.Debugf("kiro: token expired at %s (now: %s)", expTime.Format(time.RFC3339), now.Format(time.RFC3339))
+ }
+
+ return isExpired
+}
+
+// ══════════════════════════════════════════════════════════════════════════════
+// Web Search Handler (MCP API)
+// ══════════════════════════════════════════════════════════════════════════════
+
+// fetchToolDescription caching:
+// Uses a mutex + fetched flag to ensure only one goroutine fetches at a time,
+// with automatic retry on failure:
+// - On failure, fetched stays false so subsequent calls will retry
+// - On success, fetched is set to true — subsequent calls skip immediately (mutex-free fast path)
+// The cached description is stored in the translator package via kiroclaude.SetWebSearchDescription(),
+// enabling the translator's convertClaudeToolsToKiro to read it when building Kiro requests.
+var (
+ toolDescMu sync.Mutex
+ toolDescFetched atomic.Bool
+)
+
+// fetchToolDescription calls MCP tools/list to get the web_search tool description
+// and caches it. Safe to call concurrently — only one goroutine fetches at a time.
+// If the fetch fails, subsequent calls will retry. On success, no further fetches occur.
+// The httpClient parameter allows reusing a shared pooled HTTP client.
+func fetchToolDescription(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) {
+ // Fast path: already fetched successfully, no lock needed
+ if toolDescFetched.Load() {
+ return
+ }
+
+ toolDescMu.Lock()
+ defer toolDescMu.Unlock()
+
+ // Double-check after acquiring lock
+ if toolDescFetched.Load() {
+ return
+ }
+
+ handler := newWebSearchHandler(ctx, mcpEndpoint, authToken, httpClient, auth, authAttrs)
+ reqBody := []byte(`{"id":"tools_list","jsonrpc":"2.0","method":"tools/list"}`)
+ log.Debugf("kiro/websearch MCP tools/list request: %d bytes", len(reqBody))
+
+ req, err := http.NewRequestWithContext(ctx, "POST", mcpEndpoint, bytes.NewReader(reqBody))
+ if err != nil {
+ log.Warnf("kiro/websearch: failed to create tools/list request: %v", err)
+ return
+ }
+
+ // Reuse same headers as callMcpAPI
+ handler.setMcpHeaders(req)
+
+ resp, err := handler.httpClient.Do(req)
+ if err != nil {
+ log.Warnf("kiro/websearch: tools/list request failed: %v", err)
+ return
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil || resp.StatusCode != http.StatusOK {
+ log.Warnf("kiro/websearch: tools/list returned status %d", resp.StatusCode)
+ return
+ }
+ log.Debugf("kiro/websearch MCP tools/list response: [%d] %d bytes", resp.StatusCode, len(body))
+
+ // Parse: {"result":{"tools":[{"name":"web_search","description":"..."}]}}
+ var result struct {
+ Result *struct {
+ Tools []struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ } `json:"tools"`
+ } `json:"result"`
+ }
+ if err := json.Unmarshal(body, &result); err != nil || result.Result == nil {
+ log.Warnf("kiro/websearch: failed to parse tools/list response")
+ return
+ }
+
+ for _, tool := range result.Result.Tools {
+ if tool.Name == "web_search" && tool.Description != "" {
+ kiroclaude.SetWebSearchDescription(tool.Description)
+ toolDescFetched.Store(true) // success — no more fetches
+ log.Infof("kiro/websearch: cached web_search description from tools/list (%d bytes)", len(tool.Description))
+ return
+ }
+ }
+
+ // web_search tool not found in response
+ log.Warnf("kiro/websearch: web_search tool not found in tools/list response")
+}
+
+// webSearchHandler handles web search requests via Kiro MCP API
+type webSearchHandler struct {
+ ctx context.Context
+ mcpEndpoint string
+ httpClient *http.Client
+ authToken string
+ auth *cliproxyauth.Auth // for applyDynamicFingerprint
+ authAttrs map[string]string // optional, for custom headers from auth.Attributes
+}
+
+// newWebSearchHandler creates a new webSearchHandler.
+// If httpClient is nil, a default client with 30s timeout is used.
+// Pass a shared pooled client (e.g. from getKiroPooledHTTPClient) for connection reuse.
+func newWebSearchHandler(ctx context.Context, mcpEndpoint, authToken string, httpClient *http.Client, auth *cliproxyauth.Auth, authAttrs map[string]string) *webSearchHandler {
+ if httpClient == nil {
+ httpClient = &http.Client{
+ Timeout: 30 * time.Second,
+ }
+ }
+ return &webSearchHandler{
+ ctx: ctx,
+ mcpEndpoint: mcpEndpoint,
+ httpClient: httpClient,
+ authToken: authToken,
+ auth: auth,
+ authAttrs: authAttrs,
+ }
+}
+
+// setMcpHeaders sets standard MCP API headers on the request,
+// aligned with the GAR request pattern.
+func (h *webSearchHandler) setMcpHeaders(req *http.Request) {
+ // 1. Content-Type & Accept (aligned with GAR)
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "*/*")
+
+ // 2. Kiro-specific headers (aligned with GAR)
+ req.Header.Set("x-amzn-kiro-agent-mode", "vibe")
+ req.Header.Set("x-amzn-codewhisperer-optout", "true")
+
+ // 3. User-Agent: Reuse applyDynamicFingerprint for consistency
+ applyDynamicFingerprint(req, h.auth)
+
+ // 4. AWS SDK identifiers
+ req.Header.Set("Amz-Sdk-Request", "attempt=1; max=3")
+ req.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String())
+
+ // 5. Authentication
+ req.Header.Set("Authorization", "Bearer "+h.authToken)
+
+ // 6. Custom headers from auth attributes
+ util.ApplyCustomHeadersFromAttrs(req, h.authAttrs)
+}
+
+// mcpMaxRetries is the maximum number of retries for MCP API calls.
+const mcpMaxRetries = 2
+
+// callMcpAPI calls the Kiro MCP API with the given request.
+// Includes retry logic with exponential backoff for retryable errors.
+func (h *webSearchHandler) callMcpAPI(request *kiroclaude.McpRequest) (*kiroclaude.McpResponse, error) {
+ requestBody, err := json.Marshal(request)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal MCP request: %w", err)
+ }
+ log.Debugf("kiro/websearch MCP request → %s (%d bytes)", h.mcpEndpoint, len(requestBody))
+
+ var lastErr error
+ for attempt := 0; attempt <= mcpMaxRetries; attempt++ {
+ if attempt > 0 {
+ backoff := time.Duration(1< 10*time.Second {
+ backoff = 10 * time.Second
+ }
+ log.Warnf("kiro/websearch: MCP retry %d/%d after %v (last error: %v)", attempt, mcpMaxRetries, backoff, lastErr)
+ select {
+ case <-h.ctx.Done():
+ return nil, h.ctx.Err()
+ case <-time.After(backoff):
+ }
+ }
+
+ req, err := http.NewRequestWithContext(h.ctx, "POST", h.mcpEndpoint, bytes.NewReader(requestBody))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create HTTP request: %w", err)
+ }
+
+ h.setMcpHeaders(req)
+
+ resp, err := h.httpClient.Do(req)
+ if err != nil {
+ lastErr = fmt.Errorf("MCP API request failed: %w", err)
+ continue // network error → retry
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ lastErr = fmt.Errorf("failed to read MCP response: %w", err)
+ continue // read error → retry
+ }
+ log.Debugf("kiro/websearch MCP response ← [%d] (%d bytes)", resp.StatusCode, len(body))
+
+ // Retryable HTTP status codes (aligned with GAR: 502, 503, 504)
+ if resp.StatusCode >= 502 && resp.StatusCode <= 504 {
+ lastErr = fmt.Errorf("MCP API returned retryable status %d: %s", resp.StatusCode, string(body))
+ continue
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("MCP API returned status %d: %s", resp.StatusCode, string(body))
+ }
+
+ var mcpResponse kiroclaude.McpResponse
+ if err := json.Unmarshal(body, &mcpResponse); err != nil {
+ return nil, fmt.Errorf("failed to parse MCP response: %w", err)
+ }
+
+ if mcpResponse.Error != nil {
+ code := -1
+ if mcpResponse.Error.Code != nil {
+ code = *mcpResponse.Error.Code
+ }
+ msg := "Unknown error"
+ if mcpResponse.Error.Message != nil {
+ msg = *mcpResponse.Error.Message
+ }
+ return nil, fmt.Errorf("MCP error %d: %s", code, msg)
+ }
+
+ return &mcpResponse, nil
+ }
+
+ return nil, lastErr
+}
+
+// webSearchAuthAttrs extracts auth attributes for MCP calls.
+// Used by handleWebSearch and handleWebSearchStream to pass custom headers.
+func webSearchAuthAttrs(auth *cliproxyauth.Auth) map[string]string {
+ if auth != nil {
+ return auth.Attributes
+ }
+ return nil
+}
+
+const maxWebSearchIterations = 5
+
+// handleWebSearchStream handles web_search requests:
+// Step 1: tools/list (sync) → fetch/cache tool description
+// Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop
+// Note: We skip the "model decides to search" step because Claude Code already
+// decided to use web_search. The Kiro tool description restricts non-coding
+// topics, so asking the model again would cause it to refuse valid searches.
+func (e *KiroExecutor) handleWebSearchStream(
+ ctx context.Context,
+ auth *cliproxyauth.Auth,
+ req cliproxyexecutor.Request,
+ opts cliproxyexecutor.Options,
+ accessToken, profileArn string,
+) (<-chan cliproxyexecutor.StreamChunk, error) {
+ // Extract search query from Claude Code's web_search tool_use
+ query := kiroclaude.ExtractSearchQuery(req.Payload)
+ if query == "" {
+ log.Warnf("kiro/websearch: failed to extract search query, falling back to normal flow")
+ return e.callKiroDirectStream(ctx, auth, req, opts, accessToken, profileArn)
+ }
+
+ // Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
+ region := resolveKiroAPIRegion(auth)
+ mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
+
+ // ── Step 1: tools/list (SYNC) — cache tool description ──
+ {
+ authAttrs := webSearchAuthAttrs(auth)
+ fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
+ }
+
+ // Create output channel
+ out := make(chan cliproxyexecutor.StreamChunk)
+
+ // Usage reporting: track web search requests like normal streaming requests
+ reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
+
+ go func() {
+ var wsErr error
+ defer reporter.trackFailure(ctx, &wsErr)
+ defer close(out)
+
+ // Estimate input tokens using tokenizer (matching streamToChannel pattern)
+ var totalUsage usage.Detail
+ if enc, tokErr := getTokenizer(req.Model); tokErr == nil {
+ if inp, e := countClaudeChatTokens(enc, req.Payload); e == nil && inp > 0 {
+ totalUsage.InputTokens = inp
+ } else {
+ totalUsage.InputTokens = int64(len(req.Payload) / 4)
+ }
+ } else {
+ totalUsage.InputTokens = int64(len(req.Payload) / 4)
+ }
+ if totalUsage.InputTokens == 0 && len(req.Payload) > 0 {
+ totalUsage.InputTokens = 1
+ }
+ var accumulatedOutputLen int
+ defer func() {
+ if wsErr != nil {
+ return // let trackFailure handle failure reporting
+ }
+ totalUsage.OutputTokens = int64(accumulatedOutputLen / 4)
+ if accumulatedOutputLen > 0 && totalUsage.OutputTokens == 0 {
+ totalUsage.OutputTokens = 1
+ }
+ reporter.publish(ctx, totalUsage)
+ }()
+
+ // Send message_start event to client (aligned with streamToChannel pattern)
+ // Use payloadRequestedModel to return user's original model alias
+ msgStart := kiroclaude.BuildClaudeMessageStartEvent(
+ payloadRequestedModel(opts, req.Model),
+ totalUsage.InputTokens,
+ )
+ select {
+ case <-ctx.Done():
+ return
+ case out <- cliproxyexecutor.StreamChunk{Payload: append(msgStart, '\n', '\n')}:
+ }
+
+ // ── Step 2+: MCP search → InjectToolResultsClaude → callKiroAndBuffer loop ──
+ contentBlockIndex := 0
+ currentQuery := query
+
+ // Replace web_search tool description with a minimal one that allows re-search.
+ // The original tools/list description from Kiro restricts non-coding topics,
+ // but we've already decided to search. We keep the tool so the model can
+ // request additional searches when results are insufficient.
+ simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
+ if simplifyErr != nil {
+ log.Warnf("kiro/websearch: failed to simplify web_search tool: %v, using original payload", simplifyErr)
+ simplifiedPayload = bytes.Clone(req.Payload)
+ }
+
+ currentClaudePayload := simplifiedPayload
+ totalSearches := 0
+
+ // Generate toolUseId for the first iteration (Claude Code already decided to search)
+ currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
+
+ for iteration := 0; iteration < maxWebSearchIterations; iteration++ {
+ log.Infof("kiro/websearch: search iteration %d/%d",
+ iteration+1, maxWebSearchIterations)
+
+ // MCP search
+ _, mcpRequest := kiroclaude.CreateMcpRequest(currentQuery)
+
+ authAttrs := webSearchAuthAttrs(auth)
+ handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
+ mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
+
+ var searchResults *kiroclaude.WebSearchResults
+ if mcpErr != nil {
+ log.Warnf("kiro/websearch: MCP API call failed: %v, continuing with empty results", mcpErr)
+ } else {
+ searchResults = kiroclaude.ParseSearchResults(mcpResponse)
+ }
+
+ resultCount := 0
+ if searchResults != nil {
+ resultCount = len(searchResults.Results)
+ }
+ totalSearches++
+ log.Infof("kiro/websearch: iteration %d — got %d search results", iteration+1, resultCount)
+
+ // Send search indicator events to client
+ searchEvents := kiroclaude.GenerateSearchIndicatorEvents(currentQuery, currentToolUseId, searchResults, contentBlockIndex)
+ for _, event := range searchEvents {
+ select {
+ case <-ctx.Done():
+ return
+ case out <- cliproxyexecutor.StreamChunk{Payload: event}:
+ }
+ }
+ contentBlockIndex += 2
+
+ // Inject tool_use + tool_result into Claude payload, then call GAR
+ var err error
+ currentClaudePayload, err = kiroclaude.InjectToolResultsClaude(currentClaudePayload, currentToolUseId, currentQuery, searchResults)
+ if err != nil {
+ log.Warnf("kiro/websearch: failed to inject tool results: %v", err)
+ wsErr = fmt.Errorf("failed to inject tool results: %w", err)
+ e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
+ return
+ }
+
+ // Call GAR with modified Claude payload (full translation pipeline)
+ modifiedReq := req
+ modifiedReq.Payload = currentClaudePayload
+ kiroChunks, kiroErr := e.callKiroAndBuffer(ctx, auth, modifiedReq, opts, accessToken, profileArn)
+ if kiroErr != nil {
+ log.Warnf("kiro/websearch: Kiro API failed at iteration %d: %v", iteration+1, kiroErr)
+ wsErr = fmt.Errorf("Kiro API failed at iteration %d: %w", iteration+1, kiroErr)
+ e.sendFallbackText(ctx, out, contentBlockIndex, currentQuery, searchResults)
+ return
+ }
+
+ // Analyze response
+ analysis := kiroclaude.AnalyzeBufferedStream(kiroChunks)
+ log.Infof("kiro/websearch: iteration %d — stop_reason: %s, has_tool_use: %v",
+ iteration+1, analysis.StopReason, analysis.HasWebSearchToolUse)
+
+ if analysis.HasWebSearchToolUse && analysis.WebSearchQuery != "" && iteration+1 < maxWebSearchIterations {
+ // Model wants another search
+ filteredChunks := kiroclaude.FilterChunksForClient(kiroChunks, analysis.WebSearchToolUseIndex, contentBlockIndex)
+ for _, chunk := range filteredChunks {
+ select {
+ case <-ctx.Done():
+ return
+ case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
+ }
+ }
+
+ currentQuery = analysis.WebSearchQuery
+ currentToolUseId = analysis.WebSearchToolUseId
+ continue
+ }
+
+ // Model returned final response — stream to client
+ for _, chunk := range kiroChunks {
+ if contentBlockIndex > 0 && len(chunk) > 0 {
+ adjusted, shouldForward := kiroclaude.AdjustSSEChunk(chunk, contentBlockIndex)
+ if !shouldForward {
+ continue
+ }
+ accumulatedOutputLen += len(adjusted)
+ select {
+ case <-ctx.Done():
+ return
+ case out <- cliproxyexecutor.StreamChunk{Payload: adjusted}:
+ }
+ } else {
+ accumulatedOutputLen += len(chunk)
+ select {
+ case <-ctx.Done():
+ return
+ case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
+ }
+ }
+ }
+ log.Infof("kiro/websearch: completed after %d search iteration(s), total searches: %d", iteration+1, totalSearches)
+ return
+ }
+
+ log.Warnf("kiro/websearch: reached max iterations (%d), stopping search loop", maxWebSearchIterations)
+ }()
+
+ return out, nil
+}
+
+// handleWebSearch handles web_search requests for non-streaming Execute path.
+// Performs MCP search synchronously, injects results into the request payload,
+// then calls the normal non-streaming Kiro API path which returns a proper
+// Claude JSON response (not SSE chunks).
+func (e *KiroExecutor) handleWebSearch(
+ ctx context.Context,
+ auth *cliproxyauth.Auth,
+ req cliproxyexecutor.Request,
+ opts cliproxyexecutor.Options,
+ accessToken, profileArn string,
+) (cliproxyexecutor.Response, error) {
+ // Extract search query from Claude Code's web_search tool_use
+ query := kiroclaude.ExtractSearchQuery(req.Payload)
+ if query == "" {
+ log.Warnf("kiro/websearch: non-stream: failed to extract search query, falling back to normal Execute")
+ // Fall through to normal non-streaming path
+ return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
+ }
+
+ // Build MCP endpoint using shared region resolution (supports api_region + ProfileARN fallback)
+ region := resolveKiroAPIRegion(auth)
+ mcpEndpoint := kiroclaude.BuildMcpEndpoint(region)
+
+ // Step 1: Fetch/cache tool description (sync)
+ {
+ authAttrs := webSearchAuthAttrs(auth)
+ fetchToolDescription(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
+ }
+
+ // Step 2: Perform MCP search
+ _, mcpRequest := kiroclaude.CreateMcpRequest(query)
+
+ authAttrs := webSearchAuthAttrs(auth)
+ handler := newWebSearchHandler(ctx, mcpEndpoint, accessToken, newKiroHTTPClientWithPooling(ctx, e.cfg, auth, 30*time.Second), auth, authAttrs)
+ mcpResponse, mcpErr := handler.callMcpAPI(mcpRequest)
+
+ var searchResults *kiroclaude.WebSearchResults
+ if mcpErr != nil {
+ log.Warnf("kiro/websearch: non-stream: MCP API call failed: %v, continuing with empty results", mcpErr)
+ } else {
+ searchResults = kiroclaude.ParseSearchResults(mcpResponse)
+ }
+
+ resultCount := 0
+ if searchResults != nil {
+ resultCount = len(searchResults.Results)
+ }
+ log.Infof("kiro/websearch: non-stream: got %d search results", resultCount)
+
+ // Step 3: Replace restrictive web_search tool description (align with streaming path)
+ simplifiedPayload, simplifyErr := kiroclaude.ReplaceWebSearchToolDescription(bytes.Clone(req.Payload))
+ if simplifyErr != nil {
+ log.Warnf("kiro/websearch: non-stream: failed to simplify web_search tool: %v, using original payload", simplifyErr)
+ simplifiedPayload = bytes.Clone(req.Payload)
+ }
+
+ // Step 4: Inject search tool_use + tool_result into Claude payload
+ currentToolUseId := fmt.Sprintf("srvtoolu_%s", kiroclaude.GenerateToolUseID())
+ modifiedPayload, err := kiroclaude.InjectToolResultsClaude(simplifiedPayload, currentToolUseId, query, searchResults)
+ if err != nil {
+ log.Warnf("kiro/websearch: non-stream: failed to inject tool results: %v, falling back", err)
+ return e.executeNonStreamFallback(ctx, auth, req, opts, accessToken, profileArn)
+ }
+
+ // Step 5: Call Kiro API via the normal non-streaming path (executeWithRetry)
+ // This path uses parseEventStream → BuildClaudeResponse → TranslateNonStream
+ // to produce a proper Claude JSON response
+ modifiedReq := req
+ modifiedReq.Payload = modifiedPayload
+
+ resp, err := e.executeNonStreamFallback(ctx, auth, modifiedReq, opts, accessToken, profileArn)
+ if err != nil {
+ return resp, err
+ }
+
+ // Step 6: Inject server_tool_use + web_search_tool_result into response
+ // so Claude Code can display "Did X searches in Ys"
+ indicators := []kiroclaude.SearchIndicator{
+ {
+ ToolUseID: currentToolUseId,
+ Query: query,
+ Results: searchResults,
+ },
+ }
+ injectedPayload, injErr := kiroclaude.InjectSearchIndicatorsInResponse(resp.Payload, indicators)
+ if injErr != nil {
+ log.Warnf("kiro/websearch: non-stream: failed to inject search indicators: %v", injErr)
+ } else {
+ resp.Payload = injectedPayload
+ }
+
+ return resp, nil
+}
+
+// callKiroAndBuffer calls the Kiro API and buffers all response chunks.
+// Returns the buffered chunks for analysis before forwarding to client.
+// Usage reporting is NOT done here — the caller (handleWebSearchStream) manages its own reporter.
+func (e *KiroExecutor) callKiroAndBuffer(
+ ctx context.Context,
+ auth *cliproxyauth.Auth,
+ req cliproxyexecutor.Request,
+ opts cliproxyexecutor.Options,
+ accessToken, profileArn string,
+) ([][]byte, error) {
+ from := opts.SourceFormat
+ to := sdktranslator.FromString("kiro")
+ body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
+ log.Debugf("kiro/websearch GAR request: %d bytes", len(body))
+
+ kiroModelID := e.mapModelToKiro(req.Model)
+ isAgentic, isChatOnly := determineAgenticMode(req.Model)
+ effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
+
+ tokenKey := getTokenKey(auth)
+
+ kiroStream, err := e.executeStreamWithRetry(
+ ctx, auth, req, opts, accessToken, effectiveProfileArn,
+ nil, body, from, nil, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // Buffer all chunks
+ var chunks [][]byte
+ for chunk := range kiroStream {
+ if chunk.Err != nil {
+ return chunks, chunk.Err
+ }
+ if len(chunk.Payload) > 0 {
+ chunks = append(chunks, bytes.Clone(chunk.Payload))
+ }
+ }
+
+ log.Debugf("kiro/websearch GAR response: %d chunks buffered", len(chunks))
+
+ return chunks, nil
+}
+
+// callKiroDirectStream creates a direct streaming channel to Kiro API without search.
+func (e *KiroExecutor) callKiroDirectStream(
+ ctx context.Context,
+ auth *cliproxyauth.Auth,
+ req cliproxyexecutor.Request,
+ opts cliproxyexecutor.Options,
+ accessToken, profileArn string,
+) (<-chan cliproxyexecutor.StreamChunk, error) {
+ from := opts.SourceFormat
+ to := sdktranslator.FromString("kiro")
+ body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
+
+ kiroModelID := e.mapModelToKiro(req.Model)
+ isAgentic, isChatOnly := determineAgenticMode(req.Model)
+ effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
+
+ tokenKey := getTokenKey(auth)
+
+ reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
+ var streamErr error
+ defer reporter.trackFailure(ctx, &streamErr)
+
+ stream, streamErr := e.executeStreamWithRetry(
+ ctx, auth, req, opts, accessToken, effectiveProfileArn,
+ nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey,
+ )
+ return stream, streamErr
+}
+
+// sendFallbackText sends a simple text response when the Kiro API fails during the search loop.
+// Delegates SSE event construction to kiroclaude.BuildFallbackTextEvents() for alignment
+// with how streamToChannel() uses BuildClaude*Event() functions.
+func (e *KiroExecutor) sendFallbackText(
+ ctx context.Context,
+ out chan<- cliproxyexecutor.StreamChunk,
+ contentBlockIndex int,
+ query string,
+ searchResults *kiroclaude.WebSearchResults,
+) {
+ events := kiroclaude.BuildFallbackTextEvents(contentBlockIndex, query, searchResults)
+ for _, event := range events {
+ select {
+ case <-ctx.Done():
+ return
+ case out <- cliproxyexecutor.StreamChunk{Payload: append(event, '\n', '\n')}:
+ }
+ }
+}
+
+// executeNonStreamFallback runs the standard non-streaming Execute path for a request.
+// Used by handleWebSearch after injecting search results, or as a fallback.
+func (e *KiroExecutor) executeNonStreamFallback(
+ ctx context.Context,
+ auth *cliproxyauth.Auth,
+ req cliproxyexecutor.Request,
+ opts cliproxyexecutor.Options,
+ accessToken, profileArn string,
+) (cliproxyexecutor.Response, error) {
+ from := opts.SourceFormat
+ to := sdktranslator.FromString("kiro")
+ body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
+
+ kiroModelID := e.mapModelToKiro(req.Model)
+ isAgentic, isChatOnly := determineAgenticMode(req.Model)
+ effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn)
+ tokenKey := getTokenKey(auth)
+
+ reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
+ var err error
+ defer reporter.trackFailure(ctx, &err)
+
+ resp, err := e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly, tokenKey)
+ return resp, err
+}
diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go
index ab0f626a..8998eb23 100644
--- a/internal/runtime/executor/proxy_helpers.go
+++ b/internal/runtime/executor/proxy_helpers.go
@@ -6,6 +6,7 @@ import (
"net/http"
"net/url"
"strings"
+ "sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -14,11 +15,19 @@ import (
"golang.org/x/net/proxy"
)
+// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
+var (
+ httpClientCache = make(map[string]*http.Client)
+ httpClientCacheMutex sync.RWMutex
+)
+
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
// 1. Use auth.ProxyURL if configured (highest priority)
// 2. Use cfg.ProxyURL if auth proxy is not configured
// 3. Use RoundTripper from context if neither are configured
//
+// This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse.
+//
// Parameters:
// - ctx: The context containing optional RoundTripper
// - cfg: The application configuration
@@ -28,11 +37,6 @@ import (
// Returns:
// - *http.Client: An HTTP client with configured proxy or transport
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
- httpClient := &http.Client{}
- if timeout > 0 {
- httpClient.Timeout = timeout
- }
-
// Priority 1: Use auth.ProxyURL if configured
var proxyURL string
if auth != nil {
@@ -44,11 +48,39 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
proxyURL = strings.TrimSpace(cfg.ProxyURL)
}
+ // Build cache key from proxy URL (empty string for no proxy)
+ cacheKey := proxyURL
+
+ // Check cache first
+ httpClientCacheMutex.RLock()
+ if cachedClient, ok := httpClientCache[cacheKey]; ok {
+ httpClientCacheMutex.RUnlock()
+ // Return a wrapper with the requested timeout but shared transport
+ if timeout > 0 {
+ return &http.Client{
+ Transport: cachedClient.Transport,
+ Timeout: timeout,
+ }
+ }
+ return cachedClient
+ }
+ httpClientCacheMutex.RUnlock()
+
+ // Create new client
+ httpClient := &http.Client{}
+ if timeout > 0 {
+ httpClient.Timeout = timeout
+ }
+
// If we have a proxy URL configured, set up the transport
if proxyURL != "" {
transport := buildProxyTransport(proxyURL)
if transport != nil {
httpClient.Transport = transport
+ // Cache the client
+ httpClientCacheMutex.Lock()
+ httpClientCache[cacheKey] = httpClient
+ httpClientCacheMutex.Unlock()
return httpClient
}
// If proxy setup failed, log and fall through to context RoundTripper
@@ -60,6 +92,13 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
httpClient.Transport = rt
}
+ // Cache the client for no-proxy case
+ if proxyURL == "" {
+ httpClientCacheMutex.Lock()
+ httpClientCache[cacheKey] = httpClient
+ httpClientCacheMutex.Unlock()
+ }
+
return httpClient
}
diff --git a/internal/runtime/executor/token_helpers.go b/internal/runtime/executor/token_helpers.go
index f4236f9b..54188599 100644
--- a/internal/runtime/executor/token_helpers.go
+++ b/internal/runtime/executor/token_helpers.go
@@ -2,43 +2,109 @@ package executor
import (
"fmt"
+ "regexp"
+ "strconv"
"strings"
+ "sync"
"github.com/tidwall/gjson"
"github.com/tiktoken-go/tokenizer"
)
+// tokenizerCache stores tokenizer instances to avoid repeated creation
+var tokenizerCache sync.Map
+
+// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
+// where tiktoken may not accurately estimate token counts (e.g., Claude models)
+type TokenizerWrapper struct {
+ Codec tokenizer.Codec
+ AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
+}
+
+// Count returns the token count with adjustment factor applied
+func (tw *TokenizerWrapper) Count(text string) (int, error) {
+ count, err := tw.Codec.Count(text)
+ if err != nil {
+ return 0, err
+ }
+ if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
+ return int(float64(count) * tw.AdjustmentFactor), nil
+ }
+ return count, nil
+}
+
+// getTokenizer returns a cached tokenizer for the given model.
+// This improves performance by avoiding repeated tokenizer creation.
+func getTokenizer(model string) (*TokenizerWrapper, error) {
+ // Check cache first
+ if cached, ok := tokenizerCache.Load(model); ok {
+ return cached.(*TokenizerWrapper), nil
+ }
+
+ // Cache miss, create new tokenizer
+ wrapper, err := tokenizerForModel(model)
+ if err != nil {
+ return nil, err
+ }
+
+ // Store in cache (use LoadOrStore to handle race conditions)
+ actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
+ return actual.(*TokenizerWrapper), nil
+}
+
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
-func tokenizerForModel(model string) (tokenizer.Codec, error) {
+// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
+func tokenizerForModel(model string) (*TokenizerWrapper, error) {
sanitized := strings.ToLower(strings.TrimSpace(model))
+
+ // Claude models use cl100k_base with 1.1 adjustment factor
+ // because tiktoken may underestimate Claude's actual token count
+ if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
+ enc, err := tokenizer.Get(tokenizer.Cl100kBase)
+ if err != nil {
+ return nil, err
+ }
+ return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
+ }
+
+ var enc tokenizer.Codec
+ var err error
+
switch {
case sanitized == "":
- return tokenizer.Get(tokenizer.Cl100kBase)
- case strings.HasPrefix(sanitized, "gpt-5"):
- return tokenizer.ForModel(tokenizer.GPT5)
+ enc, err = tokenizer.Get(tokenizer.Cl100kBase)
+ case strings.HasPrefix(sanitized, "gpt-5.2"):
+ enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-5.1"):
- return tokenizer.ForModel(tokenizer.GPT5)
+ enc, err = tokenizer.ForModel(tokenizer.GPT5)
+ case strings.HasPrefix(sanitized, "gpt-5"):
+ enc, err = tokenizer.ForModel(tokenizer.GPT5)
case strings.HasPrefix(sanitized, "gpt-4.1"):
- return tokenizer.ForModel(tokenizer.GPT41)
+ enc, err = tokenizer.ForModel(tokenizer.GPT41)
case strings.HasPrefix(sanitized, "gpt-4o"):
- return tokenizer.ForModel(tokenizer.GPT4o)
+ enc, err = tokenizer.ForModel(tokenizer.GPT4o)
case strings.HasPrefix(sanitized, "gpt-4"):
- return tokenizer.ForModel(tokenizer.GPT4)
+ enc, err = tokenizer.ForModel(tokenizer.GPT4)
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
- return tokenizer.ForModel(tokenizer.GPT35Turbo)
+ enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
case strings.HasPrefix(sanitized, "o1"):
- return tokenizer.ForModel(tokenizer.O1)
+ enc, err = tokenizer.ForModel(tokenizer.O1)
case strings.HasPrefix(sanitized, "o3"):
- return tokenizer.ForModel(tokenizer.O3)
+ enc, err = tokenizer.ForModel(tokenizer.O3)
case strings.HasPrefix(sanitized, "o4"):
- return tokenizer.ForModel(tokenizer.O4Mini)
+ enc, err = tokenizer.ForModel(tokenizer.O4Mini)
default:
- return tokenizer.Get(tokenizer.O200kBase)
+ enc, err = tokenizer.Get(tokenizer.O200kBase)
}
+
+ if err != nil {
+ return nil, err
+ }
+ return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
}
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
-func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
+func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
if enc == nil {
return 0, fmt.Errorf("encoder is nil")
}
@@ -62,11 +128,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
return 0, nil
}
+ // Count text tokens
count, err := enc.Count(joined)
if err != nil {
return 0, err
}
- return int64(count), nil
+
+ // Extract and add image tokens from placeholders
+ imageTokens := extractImageTokens(joined)
+
+ return int64(count) + int64(imageTokens), nil
+}
+
+// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
+// This handles Claude's message format with system, messages, and tools.
+// Image tokens are estimated based on image dimensions when available.
+func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
+ if enc == nil {
+ return 0, fmt.Errorf("encoder is nil")
+ }
+ if len(payload) == 0 {
+ return 0, nil
+ }
+
+ root := gjson.ParseBytes(payload)
+ segments := make([]string, 0, 32)
+
+ // Collect system prompt (can be string or array of content blocks)
+ collectClaudeSystem(root.Get("system"), &segments)
+
+ // Collect messages
+ collectClaudeMessages(root.Get("messages"), &segments)
+
+ // Collect tools
+ collectClaudeTools(root.Get("tools"), &segments)
+
+ joined := strings.TrimSpace(strings.Join(segments, "\n"))
+ if joined == "" {
+ return 0, nil
+ }
+
+ // Count text tokens
+ count, err := enc.Count(joined)
+ if err != nil {
+ return 0, err
+ }
+
+ // Extract and add image tokens from placeholders
+ imageTokens := extractImageTokens(joined)
+
+ return int64(count) + int64(imageTokens), nil
+}
+
+// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
+var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
+
+// extractImageTokens extracts image token estimates from placeholder text.
+// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
+func extractImageTokens(text string) int {
+ matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
+ total := 0
+ for _, match := range matches {
+ if len(match) > 1 {
+ if tokens, err := strconv.Atoi(match[1]); err == nil {
+ total += tokens
+ }
+ }
+ }
+ return total
+}
+
+// estimateImageTokens calculates estimated tokens for an image based on dimensions.
+// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
+// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
+func estimateImageTokens(width, height float64) int {
+ if width <= 0 || height <= 0 {
+ // No valid dimensions, use default estimate (medium-sized image)
+ return 1000
+ }
+
+ tokens := int(width * height / 750)
+
+ // Apply bounds
+ if tokens < 85 {
+ tokens = 85
+ }
+ if tokens > 1590 {
+ tokens = 1590
+ }
+
+ return tokens
+}
+
+// collectClaudeSystem extracts text from Claude's system field.
+// System can be a string or an array of content blocks.
+func collectClaudeSystem(system gjson.Result, segments *[]string) {
+ if !system.Exists() {
+ return
+ }
+ if system.Type == gjson.String {
+ addIfNotEmpty(segments, system.String())
+ return
+ }
+ if system.IsArray() {
+ system.ForEach(func(_, block gjson.Result) bool {
+ blockType := block.Get("type").String()
+ if blockType == "text" || blockType == "" {
+ addIfNotEmpty(segments, block.Get("text").String())
+ }
+ // Also handle plain string blocks
+ if block.Type == gjson.String {
+ addIfNotEmpty(segments, block.String())
+ }
+ return true
+ })
+ }
+}
+
+// collectClaudeMessages extracts text from Claude's messages array.
+func collectClaudeMessages(messages gjson.Result, segments *[]string) {
+ if !messages.Exists() || !messages.IsArray() {
+ return
+ }
+ messages.ForEach(func(_, message gjson.Result) bool {
+ addIfNotEmpty(segments, message.Get("role").String())
+ collectClaudeContent(message.Get("content"), segments)
+ return true
+ })
+}
+
+// collectClaudeContent extracts text from Claude's content field.
+// Content can be a string or an array of content blocks.
+// For images, estimates token count based on dimensions when available.
+func collectClaudeContent(content gjson.Result, segments *[]string) {
+ if !content.Exists() {
+ return
+ }
+ if content.Type == gjson.String {
+ addIfNotEmpty(segments, content.String())
+ return
+ }
+ if content.IsArray() {
+ content.ForEach(func(_, part gjson.Result) bool {
+ partType := part.Get("type").String()
+ switch partType {
+ case "text":
+ addIfNotEmpty(segments, part.Get("text").String())
+ case "image":
+ // Estimate image tokens based on dimensions if available
+ source := part.Get("source")
+ if source.Exists() {
+ width := source.Get("width").Float()
+ height := source.Get("height").Float()
+ if width > 0 && height > 0 {
+ tokens := estimateImageTokens(width, height)
+ addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
+ } else {
+ // No dimensions available, use default estimate
+ addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
+ }
+ } else {
+ // No source info, use default estimate
+ addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
+ }
+ case "tool_use":
+ addIfNotEmpty(segments, part.Get("id").String())
+ addIfNotEmpty(segments, part.Get("name").String())
+ if input := part.Get("input"); input.Exists() {
+ addIfNotEmpty(segments, input.Raw)
+ }
+ case "tool_result":
+ addIfNotEmpty(segments, part.Get("tool_use_id").String())
+ collectClaudeContent(part.Get("content"), segments)
+ case "thinking":
+ addIfNotEmpty(segments, part.Get("thinking").String())
+ default:
+ // For unknown types, try to extract any text content
+ if part.Type == gjson.String {
+ addIfNotEmpty(segments, part.String())
+ } else if part.Type == gjson.JSON {
+ addIfNotEmpty(segments, part.Raw)
+ }
+ }
+ return true
+ })
+ }
+}
+
+// collectClaudeTools extracts text from Claude's tools array.
+func collectClaudeTools(tools gjson.Result, segments *[]string) {
+ if !tools.Exists() || !tools.IsArray() {
+ return
+ }
+ tools.ForEach(func(_, tool gjson.Result) bool {
+ addIfNotEmpty(segments, tool.Get("name").String())
+ addIfNotEmpty(segments, tool.Get("description").String())
+ if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
+ addIfNotEmpty(segments, inputSchema.Raw)
+ }
+ return true
+ })
}
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
diff --git a/internal/runtime/executor/usage_helpers.go b/internal/runtime/executor/usage_helpers.go
index 00f547df..a642fac2 100644
--- a/internal/runtime/executor/usage_helpers.go
+++ b/internal/runtime/executor/usage_helpers.go
@@ -252,6 +252,44 @@ func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
return detail, true
}
+func parseOpenAIResponsesUsageDetail(usageNode gjson.Result) usage.Detail {
+ detail := usage.Detail{
+ InputTokens: usageNode.Get("input_tokens").Int(),
+ OutputTokens: usageNode.Get("output_tokens").Int(),
+ TotalTokens: usageNode.Get("total_tokens").Int(),
+ }
+ if detail.TotalTokens == 0 {
+ detail.TotalTokens = detail.InputTokens + detail.OutputTokens
+ }
+ if cached := usageNode.Get("input_tokens_details.cached_tokens"); cached.Exists() {
+ detail.CachedTokens = cached.Int()
+ }
+ if reasoning := usageNode.Get("output_tokens_details.reasoning_tokens"); reasoning.Exists() {
+ detail.ReasoningTokens = reasoning.Int()
+ }
+ return detail
+}
+
+func parseOpenAIResponsesUsage(data []byte) usage.Detail {
+ usageNode := gjson.ParseBytes(data).Get("usage")
+ if !usageNode.Exists() {
+ return usage.Detail{}
+ }
+ return parseOpenAIResponsesUsageDetail(usageNode)
+}
+
+func parseOpenAIResponsesStreamUsage(line []byte) (usage.Detail, bool) {
+ payload := jsonPayload(line)
+ if len(payload) == 0 || !gjson.ValidBytes(payload) {
+ return usage.Detail{}, false
+ }
+ usageNode := gjson.GetBytes(payload, "usage")
+ if !usageNode.Exists() {
+ return usage.Detail{}, false
+ }
+ return parseOpenAIResponsesUsageDetail(usageNode), true
+}
+
func parseClaudeUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data).Get("usage")
if !usageNode.Exists() {
diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_response.go b/internal/translator/claude/openai/chat-completions/claude_openai_response.go
index 0ddfeaec..346db69a 100644
--- a/internal/translator/claude/openai/chat-completions/claude_openai_response.go
+++ b/internal/translator/claude/openai/chat-completions/claude_openai_response.go
@@ -50,6 +50,10 @@ type ToolCallAccumulator struct {
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
+ var localParam any
+ if param == nil {
+ param = &localParam
+ }
if *param == nil {
*param = &ConvertAnthropicResponseToOpenAIParams{
CreatedAt: 0,
diff --git a/internal/translator/init.go b/internal/translator/init.go
index 084ea7ac..0754db03 100644
--- a/internal/translator/init.go
+++ b/internal/translator/init.go
@@ -33,4 +33,7 @@ import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses"
+
+ _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude"
+ _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai"
)
diff --git a/internal/translator/kiro/claude/init.go b/internal/translator/kiro/claude/init.go
new file mode 100644
index 00000000..1685d195
--- /dev/null
+++ b/internal/translator/kiro/claude/init.go
@@ -0,0 +1,20 @@
+// Package claude provides translation between Kiro and Claude formats.
+package claude
+
+import (
+ . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
+)
+
+func init() {
+ translator.Register(
+ Claude,
+ Kiro,
+ ConvertClaudeRequestToKiro,
+ interfaces.TranslateResponse{
+ Stream: ConvertKiroStreamToClaude,
+ NonStream: ConvertKiroNonStreamToClaude,
+ },
+ )
+}
diff --git a/internal/translator/kiro/claude/kiro_claude.go b/internal/translator/kiro/claude/kiro_claude.go
new file mode 100644
index 00000000..752a00d9
--- /dev/null
+++ b/internal/translator/kiro/claude/kiro_claude.go
@@ -0,0 +1,21 @@
+// Package claude provides translation between Kiro and Claude formats.
+// Since Kiro executor generates Claude-compatible SSE format internally (with event: prefix),
+// translations are pass-through for streaming, but responses need proper formatting.
+package claude
+
+import (
+ "context"
+)
+
+// ConvertKiroStreamToClaude converts Kiro streaming response to Claude format.
+// Kiro executor already generates complete SSE format with "event:" prefix,
+// so this is a simple pass-through.
+func ConvertKiroStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
+ return []string{string(rawResponse)}
+}
+
+// ConvertKiroNonStreamToClaude converts Kiro non-streaming response to Claude format.
+// The response is already in Claude format, so this is a pass-through.
+func ConvertKiroNonStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
+ return string(rawResponse)
+}
diff --git a/internal/translator/kiro/claude/kiro_claude_request.go b/internal/translator/kiro/claude/kiro_claude_request.go
new file mode 100644
index 00000000..7012e644
--- /dev/null
+++ b/internal/translator/kiro/claude/kiro_claude_request.go
@@ -0,0 +1,958 @@
+// Package claude provides request translation functionality for Claude API to Kiro format.
+// It handles parsing and transforming Claude API requests into the Kiro/Amazon Q API format,
+// extracting model information, system instructions, message contents, and tool declarations.
+package claude
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+ "time"
+ "unicode/utf8"
+
+ "github.com/google/uuid"
+ kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
+ log "github.com/sirupsen/logrus"
+ "github.com/tidwall/gjson"
+)
+
+// remoteWebSearchDescription is a minimal fallback for when dynamic fetch from MCP tools/list hasn't completed yet.
+const remoteWebSearchDescription = "WebSearch looks up information outside the model's training data. Supports multiple queries to gather comprehensive information."
+
+// Kiro API request structs - field order determines JSON key order
+
+// KiroPayload is the top-level request structure for Kiro API
+type KiroPayload struct {
+ ConversationState KiroConversationState `json:"conversationState"`
+ ProfileArn string `json:"profileArn,omitempty"`
+ InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"`
+}
+
+// KiroInferenceConfig contains inference parameters for the Kiro API.
+type KiroInferenceConfig struct {
+ MaxTokens int `json:"maxTokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+}
+
+// KiroConversationState holds the conversation context
+type KiroConversationState struct {
+ ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field
+ ConversationID string `json:"conversationId"`
+ CurrentMessage KiroCurrentMessage `json:"currentMessage"`
+ History []KiroHistoryMessage `json:"history,omitempty"`
+}
+
+// KiroCurrentMessage wraps the current user message
+type KiroCurrentMessage struct {
+ UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
+}
+
+// KiroHistoryMessage represents a message in the conversation history
+type KiroHistoryMessage struct {
+ UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
+ AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
+}
+
+// KiroImage represents an image in Kiro API format
+type KiroImage struct {
+ Format string `json:"format"`
+ Source KiroImageSource `json:"source"`
+}
+
+// KiroImageSource contains the image data
+type KiroImageSource struct {
+ Bytes string `json:"bytes"` // base64 encoded image data
+}
+
+// KiroUserInputMessage represents a user message
+type KiroUserInputMessage struct {
+ Content string `json:"content"`
+ ModelID string `json:"modelId"`
+ Origin string `json:"origin"`
+ Images []KiroImage `json:"images,omitempty"`
+ UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
+}
+
+// KiroUserInputMessageContext contains tool-related context
+type KiroUserInputMessageContext struct {
+ ToolResults []KiroToolResult `json:"toolResults,omitempty"`
+ Tools []KiroToolWrapper `json:"tools,omitempty"`
+}
+
+// KiroToolResult represents a tool execution result
+type KiroToolResult struct {
+ Content []KiroTextContent `json:"content"`
+ Status string `json:"status"`
+ ToolUseID string `json:"toolUseId"`
+}
+
+// KiroTextContent represents text content
+type KiroTextContent struct {
+ Text string `json:"text"`
+}
+
+// KiroToolWrapper wraps a tool specification
+type KiroToolWrapper struct {
+ ToolSpecification KiroToolSpecification `json:"toolSpecification"`
+}
+
+// KiroToolSpecification defines a tool's schema
+type KiroToolSpecification struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ InputSchema KiroInputSchema `json:"inputSchema"`
+}
+
+// KiroInputSchema wraps the JSON schema for tool input
+type KiroInputSchema struct {
+ JSON interface{} `json:"json"`
+}
+
+// KiroAssistantResponseMessage represents an assistant message
+type KiroAssistantResponseMessage struct {
+ Content string `json:"content"`
+ ToolUses []KiroToolUse `json:"toolUses,omitempty"`
+}
+
+// KiroToolUse represents a tool invocation by the assistant
+type KiroToolUse struct {
+ ToolUseID string `json:"toolUseId"`
+ Name string `json:"name"`
+ Input map[string]interface{} `json:"input"`
+ IsTruncated bool `json:"-"` // Internal flag, not serialized
+ TruncationInfo *TruncationInfo `json:"-"` // Truncation details, not serialized
+}
+
+// ConvertClaudeRequestToKiro converts a Claude API request to Kiro format.
+// This is the main entry point for request translation.
+func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
+ // For Kiro, we pass through the Claude format since buildKiroPayload
+ // expects Claude format and does the conversion internally.
+ // The actual conversion happens in the executor when building the HTTP request.
+ return inputRawJSON
+}
+
+// BuildKiroPayload constructs the Kiro API request payload from Claude format.
+// Supports tool calling - tools are passed via userInputMessageContext.
+// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
+// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
+// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
+// headers parameter allows checking Anthropic-Beta header for thinking mode detection.
+// metadata parameter is kept for API compatibility but no longer used for thinking configuration.
+// Supports thinking mode - when enabled, injects thinking tags into system prompt.
+// Returns the payload and a boolean indicating whether thinking mode was injected.
+func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, headers http.Header, metadata map[string]any) ([]byte, bool) {
+ // Extract max_tokens for potential use in inferenceConfig
+ // Handle -1 as "use maximum" (Kiro max output is ~32000 tokens)
+ const kiroMaxOutputTokens = 32000
+ var maxTokens int64
+ if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() {
+ maxTokens = mt.Int()
+ if maxTokens == -1 {
+ maxTokens = kiroMaxOutputTokens
+ log.Debugf("kiro: max_tokens=-1 converted to %d", kiroMaxOutputTokens)
+ }
+ }
+
+ // Extract temperature if specified
+ var temperature float64
+ var hasTemperature bool
+ if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() {
+ temperature = temp.Float()
+ hasTemperature = true
+ }
+
+ // Extract top_p if specified
+ var topP float64
+ var hasTopP bool
+ if tp := gjson.GetBytes(claudeBody, "top_p"); tp.Exists() {
+ topP = tp.Float()
+ hasTopP = true
+ log.Debugf("kiro: extracted top_p: %.2f", topP)
+ }
+
+ // Normalize origin value for Kiro API compatibility
+ origin = normalizeOrigin(origin)
+ log.Debugf("kiro: normalized origin value: %s", origin)
+
+ messages := gjson.GetBytes(claudeBody, "messages")
+
+ // For chat-only mode, don't include tools
+ var tools gjson.Result
+ if !isChatOnly {
+ tools = gjson.GetBytes(claudeBody, "tools")
+ }
+
+ // Extract system prompt
+ systemPrompt := extractSystemPrompt(claudeBody)
+
+ // Check for thinking mode using the comprehensive IsThinkingEnabledWithHeaders function
+ // This supports Claude API format, OpenAI reasoning_effort, AMP/Cursor format, and Anthropic-Beta header
+ thinkingEnabled := IsThinkingEnabledWithHeaders(claudeBody, headers)
+
+ // Inject timestamp context
+ timestamp := time.Now().Format("2006-01-02 15:04:05 MST")
+ timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp)
+ if systemPrompt != "" {
+ systemPrompt = timestampContext + "\n\n" + systemPrompt
+ } else {
+ systemPrompt = timestampContext
+ }
+ log.Debugf("kiro: injected timestamp context: %s", timestamp)
+
+ // Inject agentic optimization prompt for -agentic model variants
+ if isAgentic {
+ if systemPrompt != "" {
+ systemPrompt += "\n"
+ }
+ systemPrompt += kirocommon.KiroAgenticSystemPrompt
+ }
+
+ // Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints
+ // Claude tool_choice values: {"type": "auto/any/tool", "name": "..."}
+ toolChoiceHint := extractClaudeToolChoiceHint(claudeBody)
+ if toolChoiceHint != "" {
+ if systemPrompt != "" {
+ systemPrompt += "\n"
+ }
+ systemPrompt += toolChoiceHint
+ log.Debugf("kiro: injected tool_choice hint into system prompt")
+ }
+
+ // Convert Claude tools to Kiro format
+ kiroTools := convertClaudeToolsToKiro(tools)
+
+ // Thinking mode implementation:
+ // Kiro API supports official thinking/reasoning mode via tag.
+ // When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
+ // rather than inline tags in assistantResponseEvent.
+ // We cap max_thinking_length to reserve space for tool outputs and prevent truncation.
+ if thinkingEnabled {
+ thinkingHint := `enabled
+16000`
+ if systemPrompt != "" {
+ systemPrompt = thinkingHint + "\n\n" + systemPrompt
+ } else {
+ systemPrompt = thinkingHint
+ }
+ log.Infof("kiro: injected thinking prompt (official mode), has_tools: %v", len(kiroTools) > 0)
+ }
+
+ // Process messages and build history
+ history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
+
+ // Build content with system prompt (only on first turn to avoid re-injection)
+ if currentUserMsg != nil {
+ effectiveSystemPrompt := systemPrompt
+ if len(history) > 0 {
+ effectiveSystemPrompt = "" // Don't re-inject on subsequent turns
+ }
+ currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, effectiveSystemPrompt, currentToolResults)
+
+ // Deduplicate currentToolResults
+ currentToolResults = deduplicateToolResults(currentToolResults)
+
+ // Build userInputMessageContext with tools and tool results
+ if len(kiroTools) > 0 || len(currentToolResults) > 0 {
+ currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{
+ Tools: kiroTools,
+ ToolResults: currentToolResults,
+ }
+ }
+ }
+
+ // Build payload
+ var currentMessage KiroCurrentMessage
+ if currentUserMsg != nil {
+ currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg}
+ } else {
+ fallbackContent := ""
+ if systemPrompt != "" {
+ fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n"
+ }
+ currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{
+ Content: fallbackContent,
+ ModelID: modelID,
+ Origin: origin,
+ }}
+ }
+
+ // Build inferenceConfig if we have any inference parameters
+ // Note: Kiro API doesn't actually use max_tokens for thinking budget
+ var inferenceConfig *KiroInferenceConfig
+ if maxTokens > 0 || hasTemperature || hasTopP {
+ inferenceConfig = &KiroInferenceConfig{}
+ if maxTokens > 0 {
+ inferenceConfig.MaxTokens = int(maxTokens)
+ }
+ if hasTemperature {
+ inferenceConfig.Temperature = temperature
+ }
+ if hasTopP {
+ inferenceConfig.TopP = topP
+ }
+ }
+
+ payload := KiroPayload{
+ ConversationState: KiroConversationState{
+ ChatTriggerType: "MANUAL",
+ ConversationID: uuid.New().String(),
+ CurrentMessage: currentMessage,
+ History: history,
+ },
+ ProfileArn: profileArn,
+ InferenceConfig: inferenceConfig,
+ }
+
+ result, err := json.Marshal(payload)
+ if err != nil {
+ log.Debugf("kiro: failed to marshal payload: %v", err)
+ return nil, false
+ }
+
+ return result, thinkingEnabled
+}
+
+// normalizeOrigin normalizes origin value for Kiro API compatibility
+func normalizeOrigin(origin string) string {
+ switch origin {
+ case "KIRO_CLI":
+ return "CLI"
+ case "KIRO_AI_EDITOR":
+ return "AI_EDITOR"
+ case "AMAZON_Q":
+ return "CLI"
+ case "KIRO_IDE":
+ return "AI_EDITOR"
+ default:
+ return origin
+ }
+}
+
+// extractSystemPrompt extracts system prompt from Claude request
+func extractSystemPrompt(claudeBody []byte) string {
+ systemField := gjson.GetBytes(claudeBody, "system")
+ if systemField.IsArray() {
+ var sb strings.Builder
+ for _, block := range systemField.Array() {
+ if block.Get("type").String() == "text" {
+ sb.WriteString(block.Get("text").String())
+ } else if block.Type == gjson.String {
+ sb.WriteString(block.String())
+ }
+ }
+ return sb.String()
+ }
+ return systemField.String()
+}
+
+// checkThinkingMode checks if thinking mode is enabled in the Claude request
+func checkThinkingMode(claudeBody []byte) (bool, int64) {
+ thinkingEnabled := false
+ var budgetTokens int64 = 24000
+
+ thinkingField := gjson.GetBytes(claudeBody, "thinking")
+ if thinkingField.Exists() {
+ thinkingType := thinkingField.Get("type").String()
+ if thinkingType == "enabled" {
+ thinkingEnabled = true
+ if bt := thinkingField.Get("budget_tokens"); bt.Exists() {
+ budgetTokens = bt.Int()
+ if budgetTokens <= 0 {
+ thinkingEnabled = false
+ log.Debugf("kiro: thinking mode disabled via budget_tokens <= 0")
+ }
+ }
+ if thinkingEnabled {
+ log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens)
+ }
+ }
+ }
+
+ return thinkingEnabled, budgetTokens
+}
+
+// hasThinkingTagInBody checks if the request body already contains thinking configuration tags.
+// This is used to prevent duplicate injection when client (e.g., AMP/Cursor) already includes thinking config.
+func hasThinkingTagInBody(body []byte) bool {
+ bodyStr := string(body)
+ return strings.Contains(bodyStr, "") || strings.Contains(bodyStr, "")
+}
+
+// IsThinkingEnabledFromHeader checks if thinking mode is enabled via Anthropic-Beta header.
+// Claude CLI uses "Anthropic-Beta: interleaved-thinking-2025-05-14" to enable thinking.
+func IsThinkingEnabledFromHeader(headers http.Header) bool {
+ if headers == nil {
+ return false
+ }
+ betaHeader := headers.Get("Anthropic-Beta")
+ if betaHeader == "" {
+ return false
+ }
+ // Check for interleaved-thinking beta feature
+ if strings.Contains(betaHeader, "interleaved-thinking") {
+ log.Debugf("kiro: thinking mode enabled via Anthropic-Beta header: %s", betaHeader)
+ return true
+ }
+ return false
+}
+
+// IsThinkingEnabled is a public wrapper to check if thinking mode is enabled.
+// This is used by the executor to determine whether to parse tags in responses.
+// When thinking is NOT enabled in the request, tags in responses should be
+// treated as regular text content, not as thinking blocks.
+//
+// Supports multiple formats:
+// - Claude API format: thinking.type = "enabled"
+// - OpenAI format: reasoning_effort parameter
+// - AMP/Cursor format: interleaved in system prompt
+func IsThinkingEnabled(body []byte) bool {
+ return IsThinkingEnabledWithHeaders(body, nil)
+}
+
+// IsThinkingEnabledWithHeaders checks if thinking mode is enabled from body or headers.
+// This is the comprehensive check that supports all thinking detection methods:
+// - Claude API format: thinking.type = "enabled"
+// - OpenAI format: reasoning_effort parameter
+// - AMP/Cursor format: interleaved in system prompt
+// - Anthropic-Beta header: interleaved-thinking-2025-05-14
+func IsThinkingEnabledWithHeaders(body []byte, headers http.Header) bool {
+ // Check Anthropic-Beta header first (Claude Code uses this)
+ if IsThinkingEnabledFromHeader(headers) {
+ return true
+ }
+
+ // Check Claude API format first (thinking.type = "enabled")
+ enabled, _ := checkThinkingMode(body)
+ if enabled {
+ log.Debugf("kiro: IsThinkingEnabled returning true (Claude API format)")
+ return true
+ }
+
+ // Check OpenAI format: reasoning_effort parameter
+ // Valid values: "low", "medium", "high", "auto" (not "none")
+ reasoningEffort := gjson.GetBytes(body, "reasoning_effort")
+ if reasoningEffort.Exists() {
+ effort := reasoningEffort.String()
+ if effort != "" && effort != "none" {
+ log.Debugf("kiro: thinking mode enabled via OpenAI reasoning_effort: %s", effort)
+ return true
+ }
+ }
+
+ // Check AMP/Cursor format: interleaved in system prompt
+ // This is how AMP client passes thinking configuration
+ bodyStr := string(body)
+ if strings.Contains(bodyStr, "") && strings.Contains(bodyStr, "") {
+ // Extract thinking mode value
+ startTag := ""
+ endTag := ""
+ startIdx := strings.Index(bodyStr, startTag)
+ if startIdx >= 0 {
+ startIdx += len(startTag)
+ endIdx := strings.Index(bodyStr[startIdx:], endTag)
+ if endIdx >= 0 {
+ thinkingMode := bodyStr[startIdx : startIdx+endIdx]
+ if thinkingMode == "interleaved" || thinkingMode == "enabled" {
+ log.Debugf("kiro: thinking mode enabled via AMP/Cursor format: %s", thinkingMode)
+ return true
+ }
+ }
+ }
+ }
+
+ // Check OpenAI format: max_completion_tokens with reasoning (o1-style)
+ // Some clients use this to indicate reasoning mode
+ if gjson.GetBytes(body, "max_completion_tokens").Exists() {
+ // If max_completion_tokens is set, check if model name suggests reasoning
+ model := gjson.GetBytes(body, "model").String()
+ if strings.Contains(strings.ToLower(model), "thinking") ||
+ strings.Contains(strings.ToLower(model), "reason") {
+ log.Debugf("kiro: thinking mode enabled via model name hint: %s", model)
+ return true
+ }
+ }
+
+ log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)")
+ return false
+}
+
+// shortenToolNameIfNeeded shortens tool names that exceed 64 characters.
+// MCP tools often have long names like "mcp__server-name__tool-name".
+// This preserves the "mcp__" prefix and last segment when possible.
+func shortenToolNameIfNeeded(name string) string {
+ const limit = 64
+ if len(name) <= limit {
+ return name
+ }
+ // For MCP tools, try to preserve prefix and last segment
+ if strings.HasPrefix(name, "mcp__") {
+ idx := strings.LastIndex(name, "__")
+ if idx > 0 {
+ cand := "mcp__" + name[idx+2:]
+ if len(cand) > limit {
+ return cand[:limit]
+ }
+ return cand
+ }
+ }
+ return name[:limit]
+}
+
+func ensureKiroInputSchema(parameters interface{}) interface{} {
+ if parameters != nil {
+ return parameters
+ }
+ return map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{},
+ }
+}
+
+// convertClaudeToolsToKiro converts Claude tools to Kiro format
+func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
+ var kiroTools []KiroToolWrapper
+ if !tools.IsArray() {
+ return kiroTools
+ }
+
+ for _, tool := range tools.Array() {
+ name := tool.Get("name").String()
+ description := tool.Get("description").String()
+ inputSchemaResult := tool.Get("input_schema")
+ var inputSchema interface{}
+ if inputSchemaResult.Exists() && inputSchemaResult.Type != gjson.Null {
+ inputSchema = inputSchemaResult.Value()
+ }
+ inputSchema = ensureKiroInputSchema(inputSchema)
+
+ // Shorten tool name if it exceeds 64 characters (common with MCP tools)
+ originalName := name
+ name = shortenToolNameIfNeeded(name)
+ if name != originalName {
+ log.Debugf("kiro: shortened tool name from '%s' to '%s'", originalName, name)
+ }
+
+ // CRITICAL FIX: Kiro API requires non-empty description
+ if strings.TrimSpace(description) == "" {
+ description = fmt.Sprintf("Tool: %s", name)
+ log.Debugf("kiro: tool '%s' has empty description, using default: %s", name, description)
+ }
+
+ // Rename web_search → remote_web_search for Kiro API compatibility
+ if name == "web_search" {
+ name = "remote_web_search"
+ // Prefer dynamically fetched description, fall back to hardcoded constant
+ if cached := GetWebSearchDescription(); cached != "" {
+ description = cached
+ } else {
+ description = remoteWebSearchDescription
+ }
+ log.Debugf("kiro: renamed tool web_search → remote_web_search")
+ }
+
+ // Truncate long descriptions (individual tool limit)
+ if len(description) > kirocommon.KiroMaxToolDescLen {
+ truncLen := kirocommon.KiroMaxToolDescLen - 30
+ for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
+ truncLen--
+ }
+ description = description[:truncLen] + "... (description truncated)"
+ }
+
+ kiroTools = append(kiroTools, KiroToolWrapper{
+ ToolSpecification: KiroToolSpecification{
+ Name: name,
+ Description: description,
+ InputSchema: KiroInputSchema{JSON: inputSchema},
+ },
+ })
+ }
+
+ // Apply dynamic compression if total tools size exceeds threshold
+ // This prevents 500 errors when Claude Code sends too many tools
+ kiroTools = compressToolsIfNeeded(kiroTools)
+
+ return kiroTools
+}
+
+// processMessages processes Claude messages and builds Kiro history
+func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) {
+ var history []KiroHistoryMessage
+ var currentUserMsg *KiroUserInputMessage
+ var currentToolResults []KiroToolResult
+
+ // Merge adjacent messages with the same role
+ messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
+
+ // FIX: Kiro API requires history to start with a user message.
+ // Some clients (e.g., OpenClaw) send conversations starting with an assistant message,
+ // which is valid for the Claude API but causes "Improperly formed request" on Kiro.
+ // Prepend a placeholder user message so the history alternation is correct.
+ if len(messagesArray) > 0 && messagesArray[0].Get("role").String() == "assistant" {
+ placeholder := `{"role":"user","content":"."}`
+ messagesArray = append([]gjson.Result{gjson.Parse(placeholder)}, messagesArray...)
+ log.Infof("kiro: messages started with assistant role, prepended placeholder user message for Kiro API compatibility")
+ }
+
+ for i, msg := range messagesArray {
+ role := msg.Get("role").String()
+ isLastMessage := i == len(messagesArray)-1
+
+ if role == "user" {
+ userMsg, toolResults := BuildUserMessageStruct(msg, modelID, origin)
+ // CRITICAL: Kiro API requires content to be non-empty for ALL user messages
+ // This includes both history messages and the current message.
+ // When user message contains only tool_result (no text), content will be empty.
+ // This commonly happens in compaction requests from OpenCode.
+ if strings.TrimSpace(userMsg.Content) == "" {
+ if len(toolResults) > 0 {
+ userMsg.Content = kirocommon.DefaultUserContentWithToolResults
+ } else {
+ userMsg.Content = kirocommon.DefaultUserContent
+ }
+ log.Debugf("kiro: user content was empty, using default: %s", userMsg.Content)
+ }
+ if isLastMessage {
+ currentUserMsg = &userMsg
+ currentToolResults = toolResults
+ } else {
+ // For history messages, embed tool results in context
+ if len(toolResults) > 0 {
+ userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
+ ToolResults: toolResults,
+ }
+ }
+ history = append(history, KiroHistoryMessage{
+ UserInputMessage: &userMsg,
+ })
+ }
+ } else if role == "assistant" {
+ assistantMsg := BuildAssistantMessageStruct(msg)
+ if isLastMessage {
+ history = append(history, KiroHistoryMessage{
+ AssistantResponseMessage: &assistantMsg,
+ })
+ // Create a "Continue" user message as currentMessage
+ currentUserMsg = &KiroUserInputMessage{
+ Content: "Continue",
+ ModelID: modelID,
+ Origin: origin,
+ }
+ } else {
+ history = append(history, KiroHistoryMessage{
+ AssistantResponseMessage: &assistantMsg,
+ })
+ }
+ }
+ }
+
+ // POST-PROCESSING: Remove orphaned tool_results that have no matching tool_use
+ // in any assistant message. This happens when Claude Code compaction truncates
+ // the conversation and removes the assistant message containing the tool_use,
+ // but keeps the user message with the corresponding tool_result.
+ // Without this fix, Kiro API returns "Improperly formed request".
+ validToolUseIDs := make(map[string]bool)
+ for _, h := range history {
+ if h.AssistantResponseMessage != nil {
+ for _, tu := range h.AssistantResponseMessage.ToolUses {
+ validToolUseIDs[tu.ToolUseID] = true
+ }
+ }
+ }
+
+ // Filter orphaned tool results from history user messages
+ for i, h := range history {
+ if h.UserInputMessage != nil && h.UserInputMessage.UserInputMessageContext != nil {
+ ctx := h.UserInputMessage.UserInputMessageContext
+ if len(ctx.ToolResults) > 0 {
+ filtered := make([]KiroToolResult, 0, len(ctx.ToolResults))
+ for _, tr := range ctx.ToolResults {
+ if validToolUseIDs[tr.ToolUseID] {
+ filtered = append(filtered, tr)
+ } else {
+ log.Debugf("kiro: dropping orphaned tool_result in history[%d]: toolUseId=%s (no matching tool_use)", i, tr.ToolUseID)
+ }
+ }
+ ctx.ToolResults = filtered
+ if len(ctx.ToolResults) == 0 && len(ctx.Tools) == 0 {
+ h.UserInputMessage.UserInputMessageContext = nil
+ }
+ }
+ }
+ }
+
+ // Filter orphaned tool results from current message
+ if len(currentToolResults) > 0 {
+ filtered := make([]KiroToolResult, 0, len(currentToolResults))
+ for _, tr := range currentToolResults {
+ if validToolUseIDs[tr.ToolUseID] {
+ filtered = append(filtered, tr)
+ } else {
+ log.Debugf("kiro: dropping orphaned tool_result in currentMessage: toolUseId=%s (no matching tool_use)", tr.ToolUseID)
+ }
+ }
+ if len(filtered) != len(currentToolResults) {
+ log.Infof("kiro: dropped %d orphaned tool_result(s) from currentMessage (compaction artifact)", len(currentToolResults)-len(filtered))
+ }
+ currentToolResults = filtered
+ }
+
+ return history, currentUserMsg, currentToolResults
+}
+
+// buildFinalContent builds the final content with system prompt
+func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string {
+ var contentBuilder strings.Builder
+
+ if systemPrompt != "" {
+ contentBuilder.WriteString("--- SYSTEM PROMPT ---\n")
+ contentBuilder.WriteString(systemPrompt)
+ contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n")
+ }
+
+ contentBuilder.WriteString(content)
+ finalContent := contentBuilder.String()
+
+ // CRITICAL: Kiro API requires content to be non-empty
+ if strings.TrimSpace(finalContent) == "" {
+ if len(toolResults) > 0 {
+ finalContent = "Tool results provided."
+ } else {
+ finalContent = "Continue"
+ }
+ log.Debugf("kiro: content was empty, using default: %s", finalContent)
+ }
+
+ return finalContent
+}
+
+// deduplicateToolResults removes duplicate tool results
+func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult {
+ if len(toolResults) == 0 {
+ return toolResults
+ }
+
+ seenIDs := make(map[string]bool)
+ unique := make([]KiroToolResult, 0, len(toolResults))
+ for _, tr := range toolResults {
+ if !seenIDs[tr.ToolUseID] {
+ seenIDs[tr.ToolUseID] = true
+ unique = append(unique, tr)
+ } else {
+ log.Debugf("kiro: skipping duplicate toolResult in currentMessage: %s", tr.ToolUseID)
+ }
+ }
+ return unique
+}
+
+// extractClaudeToolChoiceHint extracts tool_choice from Claude request and returns a system prompt hint.
+// Claude tool_choice values:
+// - {"type": "auto"}: Model decides (default, no hint needed)
+// - {"type": "any"}: Must use at least one tool
+// - {"type": "tool", "name": "..."}: Must use specific tool
+func extractClaudeToolChoiceHint(claudeBody []byte) string {
+ toolChoice := gjson.GetBytes(claudeBody, "tool_choice")
+ if !toolChoice.Exists() {
+ return ""
+ }
+
+ toolChoiceType := toolChoice.Get("type").String()
+ switch toolChoiceType {
+ case "any":
+ return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]"
+ case "tool":
+ toolName := toolChoice.Get("name").String()
+ if toolName != "" {
+ return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName)
+ }
+ case "auto":
+ // Default behavior, no hint needed
+ return ""
+ }
+
+ return ""
+}
+
+// BuildUserMessageStruct builds a user message and extracts tool results
+func BuildUserMessageStruct(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
+ content := msg.Get("content")
+ var contentBuilder strings.Builder
+ var toolResults []KiroToolResult
+ var images []KiroImage
+
+ // Track seen toolUseIds to deduplicate
+ seenToolUseIDs := make(map[string]bool)
+
+ if content.IsArray() {
+ for _, part := range content.Array() {
+ partType := part.Get("type").String()
+ switch partType {
+ case "text":
+ contentBuilder.WriteString(part.Get("text").String())
+ case "image":
+ mediaType := part.Get("source.media_type").String()
+ data := part.Get("source.data").String()
+
+ format := ""
+ if idx := strings.LastIndex(mediaType, "/"); idx != -1 {
+ format = mediaType[idx+1:]
+ }
+
+ if format != "" && data != "" {
+ images = append(images, KiroImage{
+ Format: format,
+ Source: KiroImageSource{
+ Bytes: data,
+ },
+ })
+ }
+ case "tool_result":
+ toolUseID := part.Get("tool_use_id").String()
+
+ // Skip duplicate toolUseIds
+ if seenToolUseIDs[toolUseID] {
+ log.Debugf("kiro: skipping duplicate tool_result with toolUseId: %s", toolUseID)
+ continue
+ }
+ seenToolUseIDs[toolUseID] = true
+
+ isError := part.Get("is_error").Bool()
+ resultContent := part.Get("content")
+
+ var textContents []KiroTextContent
+
+ // Check if this tool_result contains error from our SOFT_LIMIT_REACHED tool_use
+ // The client will return an error when trying to execute a tool with marker input
+ resultStr := resultContent.String()
+ isSoftLimitError := strings.Contains(resultStr, "SOFT_LIMIT_REACHED") ||
+ strings.Contains(resultStr, "_status") ||
+ strings.Contains(resultStr, "truncated") ||
+ strings.Contains(resultStr, "missing required") ||
+ strings.Contains(resultStr, "invalid input") ||
+ strings.Contains(resultStr, "Error writing file")
+
+ if isError && isSoftLimitError {
+ // Replace error content with SOFT_LIMIT_REACHED guidance
+ log.Infof("kiro: detected SOFT_LIMIT_REACHED in tool_result for %s, replacing with guidance", toolUseID)
+ softLimitMsg := `SOFT_LIMIT_REACHED
+
+Your previous tool call was incomplete due to API output size limits.
+The content was PARTIALLY transmitted but NOT executed.
+
+REQUIRED ACTION:
+1. Split your content into smaller chunks (max 300 lines per call)
+2. For file writes: Create file with first chunk, then use append for remaining
+3. Do NOT regenerate content you already attempted - continue from where you stopped
+
+STATUS: This is NOT an error. Continue with smaller chunks.`
+ textContents = append(textContents, KiroTextContent{Text: softLimitMsg})
+ // Mark as SUCCESS so Claude doesn't treat it as a failure
+ isError = false
+ } else if resultContent.IsArray() {
+ for _, item := range resultContent.Array() {
+ if item.Get("type").String() == "text" {
+ textContents = append(textContents, KiroTextContent{Text: item.Get("text").String()})
+ } else if item.Type == gjson.String {
+ textContents = append(textContents, KiroTextContent{Text: item.String()})
+ }
+ }
+ } else if resultContent.Type == gjson.String {
+ textContents = append(textContents, KiroTextContent{Text: resultContent.String()})
+ }
+
+ if len(textContents) == 0 {
+ textContents = append(textContents, KiroTextContent{Text: "Tool use was cancelled by the user"})
+ }
+
+ status := "success"
+ if isError {
+ status = "error"
+ }
+
+ toolResults = append(toolResults, KiroToolResult{
+ ToolUseID: toolUseID,
+ Content: textContents,
+ Status: status,
+ })
+ }
+ }
+ } else {
+ contentBuilder.WriteString(content.String())
+ }
+
+ userMsg := KiroUserInputMessage{
+ Content: contentBuilder.String(),
+ ModelID: modelID,
+ Origin: origin,
+ }
+
+ if len(images) > 0 {
+ userMsg.Images = images
+ }
+
+ return userMsg, toolResults
+}
+
+// BuildAssistantMessageStruct builds an assistant message with tool uses
+func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage {
+ content := msg.Get("content")
+ var contentBuilder strings.Builder
+ var toolUses []KiroToolUse
+
+ if content.IsArray() {
+ for _, part := range content.Array() {
+ partType := part.Get("type").String()
+ switch partType {
+ case "text":
+ contentBuilder.WriteString(part.Get("text").String())
+ case "tool_use":
+ toolUseID := part.Get("id").String()
+ toolName := part.Get("name").String()
+ toolInput := part.Get("input")
+
+ var inputMap map[string]interface{}
+ if toolInput.IsObject() {
+ inputMap = make(map[string]interface{})
+ toolInput.ForEach(func(key, value gjson.Result) bool {
+ inputMap[key.String()] = value.Value()
+ return true
+ })
+ }
+
+ // Rename web_search → remote_web_search to match convertClaudeToolsToKiro
+ if toolName == "web_search" {
+ toolName = "remote_web_search"
+ }
+
+ toolUses = append(toolUses, KiroToolUse{
+ ToolUseID: toolUseID,
+ Name: toolName,
+ Input: inputMap,
+ })
+ }
+ }
+ } else {
+ contentBuilder.WriteString(content.String())
+ }
+
+ // CRITICAL FIX: Kiro API requires non-empty content for assistant messages
+ // This can happen with compaction requests where assistant messages have only tool_use
+ // (no text content). Without this fix, Kiro API returns "Improperly formed request" error.
+ finalContent := contentBuilder.String()
+ if strings.TrimSpace(finalContent) == "" {
+ if len(toolUses) > 0 {
+ finalContent = kirocommon.DefaultAssistantContentWithTools
+ } else {
+ finalContent = kirocommon.DefaultAssistantContent
+ }
+ log.Debugf("kiro: assistant content was empty, using default: %s", finalContent)
+ }
+
+ return KiroAssistantResponseMessage{
+ Content: finalContent,
+ ToolUses: toolUses,
+ }
+}
diff --git a/internal/translator/kiro/claude/kiro_claude_response.go b/internal/translator/kiro/claude/kiro_claude_response.go
new file mode 100644
index 00000000..89a760cd
--- /dev/null
+++ b/internal/translator/kiro/claude/kiro_claude_response.go
@@ -0,0 +1,230 @@
+// Package claude provides response translation functionality for Kiro API to Claude format.
+// This package handles the conversion of Kiro API responses into Claude-compatible format,
+// including support for thinking blocks and tool use.
+package claude
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "encoding/json"
+ "strings"
+
+ "github.com/google/uuid"
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
+ log "github.com/sirupsen/logrus"
+
+ kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
+)
+
+// generateThinkingSignature generates a signature for thinking content.
+// This is required by Claude API for thinking blocks in non-streaming responses.
+// The signature is a base64-encoded hash of the thinking content.
+func generateThinkingSignature(thinkingContent string) string {
+ if thinkingContent == "" {
+ return ""
+ }
+ // Generate a deterministic signature based on content hash
+ hash := sha256.Sum256([]byte(thinkingContent))
+ return base64.StdEncoding.EncodeToString(hash[:])
+}
+
+// Local references to kirocommon constants for thinking block parsing
+var (
+ thinkingStartTag = kirocommon.ThinkingStartTag
+ thinkingEndTag = kirocommon.ThinkingEndTag
+)
+
+// BuildClaudeResponse constructs a Claude-compatible response.
+// Supports tool_use blocks when tools are present in the response.
+// Supports thinking blocks - parses tags and converts to Claude thinking content blocks.
+// stopReason is passed from upstream; fallback logic applied if empty.
+func BuildClaudeResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
+ var contentBlocks []map[string]interface{}
+
+ // Extract thinking blocks and text from content
+ if content != "" {
+ blocks := ExtractThinkingFromContent(content)
+ contentBlocks = append(contentBlocks, blocks...)
+
+ // Log if thinking blocks were extracted
+ for _, block := range blocks {
+ if block["type"] == "thinking" {
+ thinkingContent := block["thinking"].(string)
+ log.Infof("kiro: buildClaudeResponse extracted thinking block (len: %d)", len(thinkingContent))
+ }
+ }
+ }
+
+ // Add tool_use blocks - emit truncated tools with SOFT_LIMIT_REACHED marker
+ hasTruncatedTools := false
+ for _, toolUse := range toolUses {
+ if toolUse.IsTruncated && toolUse.TruncationInfo != nil {
+ // Emit tool_use with SOFT_LIMIT_REACHED marker input
+ hasTruncatedTools = true
+ log.Infof("kiro: buildClaudeResponse emitting truncated tool with SOFT_LIMIT_REACHED: %s (ID: %s)", toolUse.Name, toolUse.ToolUseID)
+
+ markerInput := map[string]interface{}{
+ "_status": "SOFT_LIMIT_REACHED",
+ "_message": "Tool output was truncated. Split content into smaller chunks (max 300 lines). Due to potential model hallucination, you MUST re-fetch the current working directory and generate the correct file_path.",
+ }
+
+ contentBlocks = append(contentBlocks, map[string]interface{}{
+ "type": "tool_use",
+ "id": toolUse.ToolUseID,
+ "name": toolUse.Name,
+ "input": markerInput,
+ })
+ } else {
+ // Normal tool use
+ contentBlocks = append(contentBlocks, map[string]interface{}{
+ "type": "tool_use",
+ "id": toolUse.ToolUseID,
+ "name": toolUse.Name,
+ "input": toolUse.Input,
+ })
+ }
+ }
+
+ // Log if we used SOFT_LIMIT_REACHED
+ if hasTruncatedTools {
+ log.Infof("kiro: buildClaudeResponse using SOFT_LIMIT_REACHED - keeping stop_reason=tool_use")
+ }
+
+ // Ensure at least one content block (Claude API requires non-empty content)
+ if len(contentBlocks) == 0 {
+ contentBlocks = append(contentBlocks, map[string]interface{}{
+ "type": "text",
+ "text": "",
+ })
+ }
+
+ // Use upstream stopReason; apply fallback logic if not provided
+ // SOFT_LIMIT_REACHED: Keep stop_reason = "tool_use" so Claude continues the loop
+ if stopReason == "" {
+ stopReason = "end_turn"
+ if len(toolUses) > 0 {
+ stopReason = "tool_use"
+ }
+ log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason)
+ }
+
+ // Log warning if response was truncated due to max_tokens
+ if stopReason == "max_tokens" {
+ log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)")
+ }
+
+ response := map[string]interface{}{
+ "id": "msg_" + uuid.New().String()[:24],
+ "type": "message",
+ "role": "assistant",
+ "model": model,
+ "content": contentBlocks,
+ "stop_reason": stopReason,
+ "usage": map[string]interface{}{
+ "input_tokens": usageInfo.InputTokens,
+ "output_tokens": usageInfo.OutputTokens,
+ },
+ }
+ result, _ := json.Marshal(response)
+ return result
+}
+
+// ExtractThinkingFromContent parses content to extract thinking blocks and text.
+// Returns a list of content blocks in the order they appear in the content.
+// Handles interleaved thinking and text blocks correctly.
+func ExtractThinkingFromContent(content string) []map[string]interface{} {
+ var blocks []map[string]interface{}
+
+ if content == "" {
+ return blocks
+ }
+
+ // Check if content contains thinking tags at all
+ if !strings.Contains(content, thinkingStartTag) {
+ // No thinking tags, return as plain text
+ return []map[string]interface{}{
+ {
+ "type": "text",
+ "text": content,
+ },
+ }
+ }
+
+ log.Debugf("kiro: extractThinkingFromContent - found thinking tags in content (len: %d)", len(content))
+
+ remaining := content
+
+ for len(remaining) > 0 {
+ // Look for tag
+ startIdx := strings.Index(remaining, thinkingStartTag)
+
+ if startIdx == -1 {
+ // No more thinking tags, add remaining as text
+ if strings.TrimSpace(remaining) != "" {
+ blocks = append(blocks, map[string]interface{}{
+ "type": "text",
+ "text": remaining,
+ })
+ }
+ break
+ }
+
+ // Add text before thinking tag (if any meaningful content)
+ if startIdx > 0 {
+ textBefore := remaining[:startIdx]
+ if strings.TrimSpace(textBefore) != "" {
+ blocks = append(blocks, map[string]interface{}{
+ "type": "text",
+ "text": textBefore,
+ })
+ }
+ }
+
+ // Move past the opening tag
+ remaining = remaining[startIdx+len(thinkingStartTag):]
+
+ // Find closing tag
+ endIdx := strings.Index(remaining, thinkingEndTag)
+
+ if endIdx == -1 {
+ // No closing tag found, treat rest as thinking content (incomplete response)
+ if strings.TrimSpace(remaining) != "" {
+ // Generate signature for thinking content (required by Claude API)
+ signature := generateThinkingSignature(remaining)
+ blocks = append(blocks, map[string]interface{}{
+ "type": "thinking",
+ "thinking": remaining,
+ "signature": signature,
+ })
+ log.Warnf("kiro: extractThinkingFromContent - missing closing tag")
+ }
+ break
+ }
+
+ // Extract thinking content between tags
+ thinkContent := remaining[:endIdx]
+ if strings.TrimSpace(thinkContent) != "" {
+ // Generate signature for thinking content (required by Claude API)
+ signature := generateThinkingSignature(thinkContent)
+ blocks = append(blocks, map[string]interface{}{
+ "type": "thinking",
+ "thinking": thinkContent,
+ "signature": signature,
+ })
+ log.Debugf("kiro: extractThinkingFromContent - extracted thinking block (len: %d)", len(thinkContent))
+ }
+
+ // Move past the closing tag
+ remaining = remaining[endIdx+len(thinkingEndTag):]
+ }
+
+ // If no blocks were created (all whitespace), return empty text block
+ if len(blocks) == 0 {
+ blocks = append(blocks, map[string]interface{}{
+ "type": "text",
+ "text": "",
+ })
+ }
+
+ return blocks
+}
diff --git a/internal/translator/kiro/claude/kiro_claude_stream.go b/internal/translator/kiro/claude/kiro_claude_stream.go
new file mode 100644
index 00000000..c86b6e02
--- /dev/null
+++ b/internal/translator/kiro/claude/kiro_claude_stream.go
@@ -0,0 +1,306 @@
+// Package claude provides streaming SSE event building for Claude format.
+// This package handles the construction of Claude-compatible Server-Sent Events (SSE)
+// for streaming responses from Kiro API.
+package claude
+
+import (
+ "encoding/json"
+
+ "github.com/google/uuid"
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
+)
+
+// BuildClaudeMessageStartEvent creates the message_start SSE event
+func BuildClaudeMessageStartEvent(model string, inputTokens int64) []byte {
+ event := map[string]interface{}{
+ "type": "message_start",
+ "message": map[string]interface{}{
+ "id": "msg_" + uuid.New().String()[:24],
+ "type": "message",
+ "role": "assistant",
+ "content": []interface{}{},
+ "model": model,
+ "stop_reason": nil,
+ "stop_sequence": nil,
+ "usage": map[string]interface{}{"input_tokens": inputTokens, "output_tokens": 0},
+ },
+ }
+ result, _ := json.Marshal(event)
+ return []byte("event: message_start\ndata: " + string(result))
+}
+
+// BuildClaudeContentBlockStartEvent creates a content_block_start SSE event
+func BuildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte {
+ var contentBlock map[string]interface{}
+ switch blockType {
+ case "tool_use":
+ contentBlock = map[string]interface{}{
+ "type": "tool_use",
+ "id": toolUseID,
+ "name": toolName,
+ "input": map[string]interface{}{},
+ }
+ case "thinking":
+ contentBlock = map[string]interface{}{
+ "type": "thinking",
+ "thinking": "",
+ }
+ default:
+ contentBlock = map[string]interface{}{
+ "type": "text",
+ "text": "",
+ }
+ }
+
+ event := map[string]interface{}{
+ "type": "content_block_start",
+ "index": index,
+ "content_block": contentBlock,
+ }
+ result, _ := json.Marshal(event)
+ return []byte("event: content_block_start\ndata: " + string(result))
+}
+
+// BuildClaudeStreamEvent creates a text_delta content_block_delta SSE event
+func BuildClaudeStreamEvent(contentDelta string, index int) []byte {
+ event := map[string]interface{}{
+ "type": "content_block_delta",
+ "index": index,
+ "delta": map[string]interface{}{
+ "type": "text_delta",
+ "text": contentDelta,
+ },
+ }
+ result, _ := json.Marshal(event)
+ return []byte("event: content_block_delta\ndata: " + string(result))
+}
+
+// BuildClaudeInputJsonDeltaEvent creates an input_json_delta event for tool use streaming
+func BuildClaudeInputJsonDeltaEvent(partialJSON string, index int) []byte {
+ event := map[string]interface{}{
+ "type": "content_block_delta",
+ "index": index,
+ "delta": map[string]interface{}{
+ "type": "input_json_delta",
+ "partial_json": partialJSON,
+ },
+ }
+ result, _ := json.Marshal(event)
+ return []byte("event: content_block_delta\ndata: " + string(result))
+}
+
+// BuildClaudeContentBlockStopEvent creates a content_block_stop SSE event
+func BuildClaudeContentBlockStopEvent(index int) []byte {
+ event := map[string]interface{}{
+ "type": "content_block_stop",
+ "index": index,
+ }
+ result, _ := json.Marshal(event)
+ return []byte("event: content_block_stop\ndata: " + string(result))
+}
+
+// BuildClaudeThinkingBlockStopEvent creates a content_block_stop SSE event for thinking blocks.
+func BuildClaudeThinkingBlockStopEvent(index int) []byte {
+ event := map[string]interface{}{
+ "type": "content_block_stop",
+ "index": index,
+ }
+ result, _ := json.Marshal(event)
+ return []byte("event: content_block_stop\ndata: " + string(result))
+}
+
+// BuildClaudeMessageDeltaEvent creates the message_delta event with stop_reason and usage
+func BuildClaudeMessageDeltaEvent(stopReason string, usageInfo usage.Detail) []byte {
+ deltaEvent := map[string]interface{}{
+ "type": "message_delta",
+ "delta": map[string]interface{}{
+ "stop_reason": stopReason,
+ "stop_sequence": nil,
+ },
+ "usage": map[string]interface{}{
+ "input_tokens": usageInfo.InputTokens,
+ "output_tokens": usageInfo.OutputTokens,
+ },
+ }
+ deltaResult, _ := json.Marshal(deltaEvent)
+ return []byte("event: message_delta\ndata: " + string(deltaResult))
+}
+
+// BuildClaudeMessageStopOnlyEvent creates only the message_stop event
+func BuildClaudeMessageStopOnlyEvent() []byte {
+ stopEvent := map[string]interface{}{
+ "type": "message_stop",
+ }
+ stopResult, _ := json.Marshal(stopEvent)
+ return []byte("event: message_stop\ndata: " + string(stopResult))
+}
+
+// BuildClaudePingEventWithUsage creates a ping event with embedded usage information.
+// This is used for real-time usage estimation during streaming.
+func BuildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte {
+ event := map[string]interface{}{
+ "type": "ping",
+ "usage": map[string]interface{}{
+ "input_tokens": inputTokens,
+ "output_tokens": outputTokens,
+ "total_tokens": inputTokens + outputTokens,
+ "estimated": true,
+ },
+ }
+ result, _ := json.Marshal(event)
+ return []byte("event: ping\ndata: " + string(result))
+}
+
+// BuildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility.
+// This is used when streaming thinking content wrapped in tags.
+func BuildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte {
+ event := map[string]interface{}{
+ "type": "content_block_delta",
+ "index": index,
+ "delta": map[string]interface{}{
+ "type": "thinking_delta",
+ "thinking": thinkingDelta,
+ },
+ }
+ result, _ := json.Marshal(event)
+ return []byte("event: content_block_delta\ndata: " + string(result))
+}
+
+// PendingTagSuffix detects if the buffer ends with a partial prefix of the given tag.
+// Returns the length of the partial match (0 if no match).
+// Based on amq2api implementation for handling cross-chunk tag boundaries.
+func PendingTagSuffix(buffer, tag string) int {
+ if buffer == "" || tag == "" {
+ return 0
+ }
+ maxLen := len(buffer)
+ if maxLen > len(tag)-1 {
+ maxLen = len(tag) - 1
+ }
+ for length := maxLen; length > 0; length-- {
+ if len(buffer) >= length && buffer[len(buffer)-length:] == tag[:length] {
+ return length
+ }
+ }
+ return 0
+}
+
+// GenerateSearchIndicatorEvents generates ONLY the search indicator SSE events
+// (server_tool_use + web_search_tool_result) without text summary or message termination.
+// These events trigger Claude Code's search indicator UI.
+// The caller is responsible for sending message_start before and message_delta/stop after.
+func GenerateSearchIndicatorEvents(
+ query string,
+ toolUseID string,
+ searchResults *WebSearchResults,
+ startIndex int,
+) [][]byte {
+ events := make([][]byte, 0, 5)
+
+ // 1. content_block_start (server_tool_use)
+ event1 := map[string]interface{}{
+ "type": "content_block_start",
+ "index": startIndex,
+ "content_block": map[string]interface{}{
+ "id": toolUseID,
+ "type": "server_tool_use",
+ "name": "web_search",
+ "input": map[string]interface{}{},
+ },
+ }
+ data1, _ := json.Marshal(event1)
+ events = append(events, []byte("event: content_block_start\ndata: "+string(data1)+"\n\n"))
+
+ // 2. content_block_delta (input_json_delta)
+ inputJSON, _ := json.Marshal(map[string]string{"query": query})
+ event2 := map[string]interface{}{
+ "type": "content_block_delta",
+ "index": startIndex,
+ "delta": map[string]interface{}{
+ "type": "input_json_delta",
+ "partial_json": string(inputJSON),
+ },
+ }
+ data2, _ := json.Marshal(event2)
+ events = append(events, []byte("event: content_block_delta\ndata: "+string(data2)+"\n\n"))
+
+ // 3. content_block_stop (server_tool_use)
+ event3 := map[string]interface{}{
+ "type": "content_block_stop",
+ "index": startIndex,
+ }
+ data3, _ := json.Marshal(event3)
+ events = append(events, []byte("event: content_block_stop\ndata: "+string(data3)+"\n\n"))
+
+ // 4. content_block_start (web_search_tool_result)
+ searchContent := make([]map[string]interface{}, 0)
+ if searchResults != nil {
+ for _, r := range searchResults.Results {
+ snippet := ""
+ if r.Snippet != nil {
+ snippet = *r.Snippet
+ }
+ searchContent = append(searchContent, map[string]interface{}{
+ "type": "web_search_result",
+ "title": r.Title,
+ "url": r.URL,
+ "encrypted_content": snippet,
+ "page_age": nil,
+ })
+ }
+ }
+ event4 := map[string]interface{}{
+ "type": "content_block_start",
+ "index": startIndex + 1,
+ "content_block": map[string]interface{}{
+ "type": "web_search_tool_result",
+ "tool_use_id": toolUseID,
+ "content": searchContent,
+ },
+ }
+ data4, _ := json.Marshal(event4)
+ events = append(events, []byte("event: content_block_start\ndata: "+string(data4)+"\n\n"))
+
+ // 5. content_block_stop (web_search_tool_result)
+ event5 := map[string]interface{}{
+ "type": "content_block_stop",
+ "index": startIndex + 1,
+ }
+ data5, _ := json.Marshal(event5)
+ events = append(events, []byte("event: content_block_stop\ndata: "+string(data5)+"\n\n"))
+
+ return events
+}
+
+// BuildFallbackTextEvents generates SSE events for a fallback text response
+// when the Kiro API fails during the search loop. Uses BuildClaude*Event()
+// functions to align with streamToChannel patterns.
+// Returns raw SSE byte slices ready to be sent to the client channel.
+func BuildFallbackTextEvents(contentBlockIndex int, query string, results *WebSearchResults) [][]byte {
+ summary := FormatSearchContextPrompt(query, results)
+ outputTokens := len(summary) / 4
+ if len(summary) > 0 && outputTokens == 0 {
+ outputTokens = 1
+ }
+
+ var events [][]byte
+
+ // content_block_start (text)
+ events = append(events, BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", ""))
+
+ // content_block_delta (text_delta)
+ events = append(events, BuildClaudeStreamEvent(summary, contentBlockIndex))
+
+ // content_block_stop
+ events = append(events, BuildClaudeContentBlockStopEvent(contentBlockIndex))
+
+ // message_delta with end_turn
+ events = append(events, BuildClaudeMessageDeltaEvent("end_turn", usage.Detail{
+ OutputTokens: int64(outputTokens),
+ }))
+
+ // message_stop
+ events = append(events, BuildClaudeMessageStopOnlyEvent())
+
+ return events
+}
diff --git a/internal/translator/kiro/claude/kiro_claude_stream_parser.go b/internal/translator/kiro/claude/kiro_claude_stream_parser.go
new file mode 100644
index 00000000..275196ac
--- /dev/null
+++ b/internal/translator/kiro/claude/kiro_claude_stream_parser.go
@@ -0,0 +1,350 @@
+package claude
+
+import (
+ "encoding/json"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// sseEvent represents a Server-Sent Event
+type sseEvent struct {
+ Event string
+ Data interface{}
+}
+
+// ToSSEString converts the event to SSE wire format
+func (e *sseEvent) ToSSEString() string {
+ dataBytes, _ := json.Marshal(e.Data)
+ return "event: " + e.Event + "\ndata: " + string(dataBytes) + "\n\n"
+}
+
+// AdjustStreamIndices adjusts content block indices in SSE event data by adding an offset.
+// It also suppresses duplicate message_start events (returns shouldForward=false).
+// This is used to combine search indicator events (indices 0,1) with Kiro model response events.
+//
+// The data parameter is a single SSE "data:" line payload (JSON).
+// Returns: adjusted data, shouldForward (false = skip this event).
+func AdjustStreamIndices(data []byte, offset int) ([]byte, bool) {
+ if len(data) == 0 {
+ return data, true
+ }
+
+ // Quick check: parse the JSON
+ var event map[string]interface{}
+ if err := json.Unmarshal(data, &event); err != nil {
+ // Not valid JSON, pass through
+ return data, true
+ }
+
+ eventType, _ := event["type"].(string)
+
+ // Suppress duplicate message_start events
+ if eventType == "message_start" {
+ return data, false
+ }
+
+ // Adjust index for content_block events
+ switch eventType {
+ case "content_block_start", "content_block_delta", "content_block_stop":
+ if idx, ok := event["index"].(float64); ok {
+ event["index"] = int(idx) + offset
+ adjusted, err := json.Marshal(event)
+ if err != nil {
+ return data, true
+ }
+ return adjusted, true
+ }
+ }
+
+ // Pass through all other events unchanged (message_delta, message_stop, ping, etc.)
+ return data, true
+}
+
+// AdjustSSEChunk processes a raw SSE chunk (potentially containing multiple "event:/data:" pairs)
+// and adjusts content block indices. Suppresses duplicate message_start events.
+// Returns the adjusted chunk and whether it should be forwarded.
+func AdjustSSEChunk(chunk []byte, offset int) ([]byte, bool) {
+ chunkStr := string(chunk)
+
+ // Fast path: if no "data:" prefix, pass through
+ if !strings.Contains(chunkStr, "data: ") {
+ return chunk, true
+ }
+
+ var result strings.Builder
+ hasContent := false
+
+ lines := strings.Split(chunkStr, "\n")
+ for i := 0; i < len(lines); i++ {
+ line := lines[i]
+
+ if strings.HasPrefix(line, "data: ") {
+ dataPayload := strings.TrimPrefix(line, "data: ")
+ dataPayload = strings.TrimSpace(dataPayload)
+
+ if dataPayload == "[DONE]" {
+ result.WriteString(line + "\n")
+ hasContent = true
+ continue
+ }
+
+ adjusted, shouldForward := AdjustStreamIndices([]byte(dataPayload), offset)
+ if !shouldForward {
+ // Skip this event and its preceding "event:" line
+ // Also skip the trailing empty line
+ continue
+ }
+
+ result.WriteString("data: " + string(adjusted) + "\n")
+ hasContent = true
+ } else if strings.HasPrefix(line, "event: ") {
+ // Check if the next data line will be suppressed
+ if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
+ dataPayload := strings.TrimPrefix(lines[i+1], "data: ")
+ dataPayload = strings.TrimSpace(dataPayload)
+
+ var event map[string]interface{}
+ if err := json.Unmarshal([]byte(dataPayload), &event); err == nil {
+ if eventType, ok := event["type"].(string); ok && eventType == "message_start" {
+ // Skip both the event: and data: lines
+ i++ // skip the data: line too
+ continue
+ }
+ }
+ }
+ result.WriteString(line + "\n")
+ hasContent = true
+ } else {
+ result.WriteString(line + "\n")
+ if strings.TrimSpace(line) != "" {
+ hasContent = true
+ }
+ }
+ }
+
+ if !hasContent {
+ return nil, false
+ }
+
+ return []byte(result.String()), true
+}
+
+// BufferedStreamResult contains the analysis of buffered SSE chunks from a Kiro API response.
+type BufferedStreamResult struct {
+ // StopReason is the detected stop_reason from the stream (e.g., "end_turn", "tool_use")
+ StopReason string
+ // WebSearchQuery is the extracted query if the model requested another web_search
+ WebSearchQuery string
+ // WebSearchToolUseId is the tool_use ID from the model's response (needed for toolResults)
+ WebSearchToolUseId string
+ // HasWebSearchToolUse indicates whether the model requested web_search
+ HasWebSearchToolUse bool
+ // WebSearchToolUseIndex is the content_block index of the web_search tool_use
+ WebSearchToolUseIndex int
+}
+
+// AnalyzeBufferedStream scans buffered SSE chunks to detect stop_reason and web_search tool_use.
+// This is used in the search loop to determine if the model wants another search round.
+func AnalyzeBufferedStream(chunks [][]byte) BufferedStreamResult {
+ result := BufferedStreamResult{WebSearchToolUseIndex: -1}
+
+ // Track tool use state across chunks
+ var currentToolName string
+ var currentToolIndex int = -1
+ var toolInputBuilder strings.Builder
+
+ for _, chunk := range chunks {
+ chunkStr := string(chunk)
+ lines := strings.Split(chunkStr, "\n")
+ for _, line := range lines {
+ if !strings.HasPrefix(line, "data: ") {
+ continue
+ }
+ dataPayload := strings.TrimPrefix(line, "data: ")
+ dataPayload = strings.TrimSpace(dataPayload)
+ if dataPayload == "[DONE]" || dataPayload == "" {
+ continue
+ }
+
+ var event map[string]interface{}
+ if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
+ continue
+ }
+
+ eventType, _ := event["type"].(string)
+
+ switch eventType {
+ case "message_delta":
+ // Extract stop_reason from message_delta
+ if delta, ok := event["delta"].(map[string]interface{}); ok {
+ if sr, ok := delta["stop_reason"].(string); ok && sr != "" {
+ result.StopReason = sr
+ }
+ }
+
+ case "content_block_start":
+ // Detect tool_use content blocks
+ if cb, ok := event["content_block"].(map[string]interface{}); ok {
+ if cbType, ok := cb["type"].(string); ok && cbType == "tool_use" {
+ if name, ok := cb["name"].(string); ok {
+ currentToolName = strings.ToLower(name)
+ if idx, ok := event["index"].(float64); ok {
+ currentToolIndex = int(idx)
+ }
+ // Capture tool use ID for toolResults handshake
+ if id, ok := cb["id"].(string); ok {
+ result.WebSearchToolUseId = id
+ }
+ toolInputBuilder.Reset()
+ }
+ }
+ }
+
+ case "content_block_delta":
+ // Accumulate tool input JSON
+ if currentToolName != "" {
+ if delta, ok := event["delta"].(map[string]interface{}); ok {
+ if deltaType, ok := delta["type"].(string); ok && deltaType == "input_json_delta" {
+ if partial, ok := delta["partial_json"].(string); ok {
+ toolInputBuilder.WriteString(partial)
+ }
+ }
+ }
+ }
+
+ case "content_block_stop":
+ // Finalize tool use detection
+ if currentToolName == "web_search" || currentToolName == "websearch" || currentToolName == "remote_web_search" {
+ result.HasWebSearchToolUse = true
+ result.WebSearchToolUseIndex = currentToolIndex
+ // Extract query from accumulated input JSON
+ inputJSON := toolInputBuilder.String()
+ var input map[string]string
+ if err := json.Unmarshal([]byte(inputJSON), &input); err == nil {
+ if q, ok := input["query"]; ok {
+ result.WebSearchQuery = q
+ }
+ }
+ log.Debugf("kiro/websearch: detected web_search tool_use")
+ }
+ currentToolName = ""
+ currentToolIndex = -1
+ toolInputBuilder.Reset()
+ }
+ }
+ }
+
+ return result
+}
+
+// FilterChunksForClient processes buffered SSE chunks and removes web_search tool_use
+// content blocks. This prevents the client from seeing "Tool use" prompts for web_search
+// when the proxy is handling the search loop internally.
+// Also suppresses message_start and message_delta/message_stop events since those
+// are managed by the outer handleWebSearchStream.
+func FilterChunksForClient(chunks [][]byte, wsToolIndex int, indexOffset int) [][]byte {
+ var filtered [][]byte
+
+ for _, chunk := range chunks {
+ chunkStr := string(chunk)
+ lines := strings.Split(chunkStr, "\n")
+
+ var resultBuilder strings.Builder
+ hasContent := false
+
+ for i := 0; i < len(lines); i++ {
+ line := lines[i]
+
+ if strings.HasPrefix(line, "data: ") {
+ dataPayload := strings.TrimPrefix(line, "data: ")
+ dataPayload = strings.TrimSpace(dataPayload)
+
+ if dataPayload == "[DONE]" {
+ // Skip [DONE] — the outer loop manages stream termination
+ continue
+ }
+
+ var event map[string]interface{}
+ if err := json.Unmarshal([]byte(dataPayload), &event); err != nil {
+ resultBuilder.WriteString(line + "\n")
+ hasContent = true
+ continue
+ }
+
+ eventType, _ := event["type"].(string)
+
+ // Skip message_start (outer loop sends its own)
+ if eventType == "message_start" {
+ continue
+ }
+
+ // Skip message_delta and message_stop (outer loop manages these)
+ if eventType == "message_delta" || eventType == "message_stop" {
+ continue
+ }
+
+ // Check if this event belongs to the web_search tool_use block
+ if wsToolIndex >= 0 {
+ if idx, ok := event["index"].(float64); ok && int(idx) == wsToolIndex {
+ // Skip events for the web_search tool_use block
+ continue
+ }
+ }
+
+ // Apply index offset for remaining events
+ if indexOffset > 0 {
+ switch eventType {
+ case "content_block_start", "content_block_delta", "content_block_stop":
+ if idx, ok := event["index"].(float64); ok {
+ event["index"] = int(idx) + indexOffset
+ adjusted, err := json.Marshal(event)
+ if err == nil {
+ resultBuilder.WriteString("data: " + string(adjusted) + "\n")
+ hasContent = true
+ continue
+ }
+ }
+ }
+ }
+
+ resultBuilder.WriteString(line + "\n")
+ hasContent = true
+ } else if strings.HasPrefix(line, "event: ") {
+ // Check if the next data line will be suppressed
+ if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") {
+ nextData := strings.TrimPrefix(lines[i+1], "data: ")
+ nextData = strings.TrimSpace(nextData)
+
+ var nextEvent map[string]interface{}
+ if err := json.Unmarshal([]byte(nextData), &nextEvent); err == nil {
+ nextType, _ := nextEvent["type"].(string)
+ if nextType == "message_start" || nextType == "message_delta" || nextType == "message_stop" {
+ i++ // skip the data line
+ continue
+ }
+ if wsToolIndex >= 0 {
+ if idx, ok := nextEvent["index"].(float64); ok && int(idx) == wsToolIndex {
+ i++ // skip the data line
+ continue
+ }
+ }
+ }
+ }
+ resultBuilder.WriteString(line + "\n")
+ hasContent = true
+ } else {
+ resultBuilder.WriteString(line + "\n")
+ if strings.TrimSpace(line) != "" {
+ hasContent = true
+ }
+ }
+ }
+
+ if hasContent {
+ filtered = append(filtered, []byte(resultBuilder.String()))
+ }
+ }
+
+ return filtered
+}
diff --git a/internal/translator/kiro/claude/kiro_claude_tools.go b/internal/translator/kiro/claude/kiro_claude_tools.go
new file mode 100644
index 00000000..d00c7493
--- /dev/null
+++ b/internal/translator/kiro/claude/kiro_claude_tools.go
@@ -0,0 +1,543 @@
+// Package claude provides tool calling support for Kiro to Claude translation.
+// This package handles parsing embedded tool calls, JSON repair, and deduplication.
+package claude
+
+import (
+ "encoding/json"
+ "regexp"
+ "strings"
+
+ "github.com/google/uuid"
+ kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
+ log "github.com/sirupsen/logrus"
+)
+
+// ToolUseState tracks the state of an in-progress tool use during streaming.
+type ToolUseState struct {
+ ToolUseID string
+ Name string
+ InputBuffer strings.Builder
+ IsComplete bool
+ TruncationInfo *TruncationInfo // Truncation detection result (set when complete)
+}
+
+// Pre-compiled regex patterns for performance
+var (
+ // embeddedToolCallPattern matches [Called tool_name with args: {...}] format
+ embeddedToolCallPattern = regexp.MustCompile(`\[Called\s+([A-Za-z0-9_.-]+)\s+with\s+args:\s*`)
+ // trailingCommaPattern matches trailing commas before closing braces/brackets
+ trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`)
+)
+
+// ParseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text.
+// Kiro sometimes embeds tool calls in text content instead of using toolUseEvent.
+// Returns the cleaned text (with tool calls removed) and extracted tool uses.
+func ParseEmbeddedToolCalls(text string, processedIDs map[string]bool) (string, []KiroToolUse) {
+ if !strings.Contains(text, "[Called") {
+ return text, nil
+ }
+
+ var toolUses []KiroToolUse
+ cleanText := text
+
+ // Find all [Called markers
+ matches := embeddedToolCallPattern.FindAllStringSubmatchIndex(text, -1)
+ if len(matches) == 0 {
+ return text, nil
+ }
+
+ // Process matches in reverse order to maintain correct indices
+ for i := len(matches) - 1; i >= 0; i-- {
+ matchStart := matches[i][0]
+ toolNameStart := matches[i][2]
+ toolNameEnd := matches[i][3]
+
+ if toolNameStart < 0 || toolNameEnd < 0 {
+ continue
+ }
+
+ toolName := text[toolNameStart:toolNameEnd]
+
+ // Find the JSON object start (after "with args:")
+ jsonStart := matches[i][1]
+ if jsonStart >= len(text) {
+ continue
+ }
+
+ // Skip whitespace to find the opening brace
+ for jsonStart < len(text) && (text[jsonStart] == ' ' || text[jsonStart] == '\t') {
+ jsonStart++
+ }
+
+ if jsonStart >= len(text) || text[jsonStart] != '{' {
+ continue
+ }
+
+ // Find matching closing bracket
+ jsonEnd := findMatchingBracket(text, jsonStart)
+ if jsonEnd < 0 {
+ continue
+ }
+
+ // Extract JSON and find the closing bracket of [Called ...]
+ jsonStr := text[jsonStart : jsonEnd+1]
+
+ // Find the closing ] after the JSON
+ closingBracket := jsonEnd + 1
+ for closingBracket < len(text) && text[closingBracket] != ']' {
+ closingBracket++
+ }
+ if closingBracket >= len(text) {
+ continue
+ }
+
+ // End index of the full tool call (closing ']' inclusive)
+ matchEnd := closingBracket + 1
+
+ // Repair and parse JSON
+ repairedJSON := RepairJSON(jsonStr)
+ var inputMap map[string]interface{}
+ if err := json.Unmarshal([]byte(repairedJSON), &inputMap); err != nil {
+ log.Debugf("kiro: failed to parse embedded tool call JSON: %v, raw: %s", err, jsonStr)
+ continue
+ }
+
+ // Generate unique tool ID
+ toolUseID := "toolu_" + uuid.New().String()[:12]
+
+ // Check for duplicates using name+input as key
+ dedupeKey := toolName + ":" + repairedJSON
+ if processedIDs != nil {
+ if processedIDs[dedupeKey] {
+ log.Debugf("kiro: skipping duplicate embedded tool call: %s", toolName)
+ // Still remove from text even if duplicate
+ if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd {
+ cleanText = cleanText[:matchStart] + cleanText[matchEnd:]
+ }
+ continue
+ }
+ processedIDs[dedupeKey] = true
+ }
+
+ toolUses = append(toolUses, KiroToolUse{
+ ToolUseID: toolUseID,
+ Name: toolName,
+ Input: inputMap,
+ })
+
+ log.Infof("kiro: extracted embedded tool call: %s (ID: %s)", toolName, toolUseID)
+
+ // Remove from clean text (index-based removal to avoid deleting the wrong occurrence)
+ if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd {
+ cleanText = cleanText[:matchStart] + cleanText[matchEnd:]
+ }
+ }
+
+ return cleanText, toolUses
+}
+
+// findMatchingBracket finds the index of the closing brace/bracket that matches
+// the opening one at startPos. Handles nested objects and strings correctly.
+func findMatchingBracket(text string, startPos int) int {
+ if startPos >= len(text) {
+ return -1
+ }
+
+ openChar := text[startPos]
+ var closeChar byte
+ switch openChar {
+ case '{':
+ closeChar = '}'
+ case '[':
+ closeChar = ']'
+ default:
+ return -1
+ }
+
+ depth := 1
+ inString := false
+ escapeNext := false
+
+ for i := startPos + 1; i < len(text); i++ {
+ char := text[i]
+
+ if escapeNext {
+ escapeNext = false
+ continue
+ }
+
+ if char == '\\' && inString {
+ escapeNext = true
+ continue
+ }
+
+ if char == '"' {
+ inString = !inString
+ continue
+ }
+
+ if !inString {
+ if char == openChar {
+ depth++
+ } else if char == closeChar {
+ depth--
+ if depth == 0 {
+ return i
+ }
+ }
+ }
+ }
+
+ return -1
+}
+
+// RepairJSON attempts to fix common JSON issues that may occur in tool call arguments.
+// Conservative repair strategy:
+// 1. First try to parse JSON directly - if valid, return as-is
+// 2. Only attempt repair if parsing fails
+// 3. After repair, validate the result - if still invalid, return original
+func RepairJSON(jsonString string) string {
+ // Handle empty or invalid input
+ if jsonString == "" {
+ return "{}"
+ }
+
+ str := strings.TrimSpace(jsonString)
+ if str == "" {
+ return "{}"
+ }
+
+ // CONSERVATIVE STRATEGY: First try to parse directly
+ var testParse interface{}
+ if err := json.Unmarshal([]byte(str), &testParse); err == nil {
+ log.Debugf("kiro: repairJSON - JSON is already valid, returning unchanged")
+ return str
+ }
+
+ log.Debugf("kiro: repairJSON - JSON parse failed, attempting repair")
+ originalStr := str
+
+ // First, escape unescaped newlines/tabs within JSON string values
+ str = escapeNewlinesInStrings(str)
+ // Remove trailing commas before closing braces/brackets
+ str = trailingCommaPattern.ReplaceAllString(str, "$1")
+
+ // Calculate bracket balance
+ braceCount := 0
+ bracketCount := 0
+ inString := false
+ escape := false
+ lastValidIndex := -1
+
+ for i := 0; i < len(str); i++ {
+ char := str[i]
+
+ if escape {
+ escape = false
+ continue
+ }
+
+ if char == '\\' {
+ escape = true
+ continue
+ }
+
+ if char == '"' {
+ inString = !inString
+ continue
+ }
+
+ if inString {
+ continue
+ }
+
+ switch char {
+ case '{':
+ braceCount++
+ case '}':
+ braceCount--
+ case '[':
+ bracketCount++
+ case ']':
+ bracketCount--
+ }
+
+ if braceCount >= 0 && bracketCount >= 0 {
+ lastValidIndex = i
+ }
+ }
+
+ // If brackets are unbalanced, try to repair
+ if braceCount > 0 || bracketCount > 0 {
+ if lastValidIndex > 0 && lastValidIndex < len(str)-1 {
+ truncated := str[:lastValidIndex+1]
+ // Recount brackets after truncation
+ braceCount = 0
+ bracketCount = 0
+ inString = false
+ escape = false
+ for i := 0; i < len(truncated); i++ {
+ char := truncated[i]
+ if escape {
+ escape = false
+ continue
+ }
+ if char == '\\' {
+ escape = true
+ continue
+ }
+ if char == '"' {
+ inString = !inString
+ continue
+ }
+ if inString {
+ continue
+ }
+ switch char {
+ case '{':
+ braceCount++
+ case '}':
+ braceCount--
+ case '[':
+ bracketCount++
+ case ']':
+ bracketCount--
+ }
+ }
+ str = truncated
+ }
+
+ // Add missing closing brackets
+ for braceCount > 0 {
+ str += "}"
+ braceCount--
+ }
+ for bracketCount > 0 {
+ str += "]"
+ bracketCount--
+ }
+ }
+
+ // Validate repaired JSON
+ if err := json.Unmarshal([]byte(str), &testParse); err != nil {
+ log.Warnf("kiro: repairJSON - repair failed to produce valid JSON, returning original")
+ return originalStr
+ }
+
+ log.Debugf("kiro: repairJSON - successfully repaired JSON")
+ return str
+}
+
+// escapeNewlinesInStrings escapes literal newlines, tabs, and other control characters
+// that appear inside JSON string values.
+func escapeNewlinesInStrings(raw string) string {
+ var result strings.Builder
+ result.Grow(len(raw) + 100)
+
+ inString := false
+ escaped := false
+
+ for i := 0; i < len(raw); i++ {
+ c := raw[i]
+
+ if escaped {
+ result.WriteByte(c)
+ escaped = false
+ continue
+ }
+
+ if c == '\\' && inString {
+ result.WriteByte(c)
+ escaped = true
+ continue
+ }
+
+ if c == '"' {
+ inString = !inString
+ result.WriteByte(c)
+ continue
+ }
+
+ if inString {
+ switch c {
+ case '\n':
+ result.WriteString("\\n")
+ case '\r':
+ result.WriteString("\\r")
+ case '\t':
+ result.WriteString("\\t")
+ default:
+ result.WriteByte(c)
+ }
+ } else {
+ result.WriteByte(c)
+ }
+ }
+
+ return result.String()
+}
+
+// ProcessToolUseEvent handles a toolUseEvent from the Kiro stream.
+// It accumulates input fragments and emits tool_use blocks when complete.
+// Returns events to emit and updated state.
+func ProcessToolUseEvent(event map[string]interface{}, currentToolUse *ToolUseState, processedIDs map[string]bool) ([]KiroToolUse, *ToolUseState) {
+ var toolUses []KiroToolUse
+
+ // Extract from nested toolUseEvent or direct format
+ tu := event
+ if nested, ok := event["toolUseEvent"].(map[string]interface{}); ok {
+ tu = nested
+ }
+
+ toolUseID := kirocommon.GetString(tu, "toolUseId")
+ toolName := kirocommon.GetString(tu, "name")
+ isStop := false
+ if stop, ok := tu["stop"].(bool); ok {
+ isStop = stop
+ }
+
+ // Get input - can be string (fragment) or object (complete)
+ var inputFragment string
+ var inputMap map[string]interface{}
+
+ if inputRaw, ok := tu["input"]; ok {
+ switch v := inputRaw.(type) {
+ case string:
+ inputFragment = v
+ case map[string]interface{}:
+ inputMap = v
+ }
+ }
+
+ // New tool use starting
+ if toolUseID != "" && toolName != "" {
+ if currentToolUse != nil && currentToolUse.ToolUseID != toolUseID {
+ log.Warnf("kiro: interleaved tool use detected - new ID %s arrived while %s in progress, completing previous",
+ toolUseID, currentToolUse.ToolUseID)
+ if !processedIDs[currentToolUse.ToolUseID] {
+ incomplete := KiroToolUse{
+ ToolUseID: currentToolUse.ToolUseID,
+ Name: currentToolUse.Name,
+ }
+ if currentToolUse.InputBuffer.Len() > 0 {
+ raw := currentToolUse.InputBuffer.String()
+ repaired := RepairJSON(raw)
+
+ var input map[string]interface{}
+ if err := json.Unmarshal([]byte(repaired), &input); err != nil {
+ log.Warnf("kiro: failed to parse interleaved tool input: %v, raw: %s", err, raw)
+ input = make(map[string]interface{})
+ }
+ incomplete.Input = input
+ }
+ toolUses = append(toolUses, incomplete)
+ processedIDs[currentToolUse.ToolUseID] = true
+ }
+ currentToolUse = nil
+ }
+
+ if currentToolUse == nil {
+ if processedIDs != nil && processedIDs[toolUseID] {
+ log.Debugf("kiro: skipping duplicate toolUseEvent: %s", toolUseID)
+ return nil, nil
+ }
+
+ currentToolUse = &ToolUseState{
+ ToolUseID: toolUseID,
+ Name: toolName,
+ }
+ log.Infof("kiro: starting new tool use: %s (ID: %s)", toolName, toolUseID)
+ }
+ }
+
+ // Accumulate input fragments
+ if currentToolUse != nil && inputFragment != "" {
+ currentToolUse.InputBuffer.WriteString(inputFragment)
+ log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.InputBuffer.Len())
+ }
+
+ // If complete input object provided directly
+ if currentToolUse != nil && inputMap != nil {
+ inputBytes, _ := json.Marshal(inputMap)
+ currentToolUse.InputBuffer.Reset()
+ currentToolUse.InputBuffer.Write(inputBytes)
+ }
+
+ // Tool use complete
+ if isStop && currentToolUse != nil {
+ fullInput := currentToolUse.InputBuffer.String()
+
+ // Repair and parse the accumulated JSON
+ repairedJSON := RepairJSON(fullInput)
+ var finalInput map[string]interface{}
+ if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil {
+ log.Warnf("kiro: failed to parse accumulated tool input: %v, raw: %s", err, fullInput)
+ finalInput = make(map[string]interface{})
+ }
+
+ // Detect truncation for all tools
+ truncInfo := DetectTruncation(currentToolUse.Name, currentToolUse.ToolUseID, fullInput, finalInput)
+ if truncInfo.IsTruncated {
+ log.Warnf("kiro: TRUNCATION DETECTED for tool %s (ID: %s): type=%s, raw_size=%d bytes",
+ currentToolUse.Name, currentToolUse.ToolUseID, truncInfo.TruncationType, len(fullInput))
+ log.Warnf("kiro: truncation details: %s", truncInfo.ErrorMessage)
+ if len(truncInfo.ParsedFields) > 0 {
+ log.Infof("kiro: partial fields received: %v", truncInfo.ParsedFields)
+ }
+ // Store truncation info in the state for upstream handling
+ currentToolUse.TruncationInfo = &truncInfo
+ } else {
+ log.Infof("kiro: tool use %s input length: %d bytes (no truncation)", currentToolUse.Name, len(fullInput))
+ }
+
+ // Create the tool use with truncation info if applicable
+ toolUse := KiroToolUse{
+ ToolUseID: currentToolUse.ToolUseID,
+ Name: currentToolUse.Name,
+ Input: finalInput,
+ IsTruncated: truncInfo.IsTruncated,
+ TruncationInfo: nil, // Will be set below if truncated
+ }
+ if truncInfo.IsTruncated {
+ toolUse.TruncationInfo = &truncInfo
+ }
+ toolUses = append(toolUses, toolUse)
+
+ if processedIDs != nil {
+ processedIDs[currentToolUse.ToolUseID] = true
+ }
+
+ log.Infof("kiro: completed tool use: %s (ID: %s, truncated: %v)", currentToolUse.Name, currentToolUse.ToolUseID, truncInfo.IsTruncated)
+ return toolUses, nil
+ }
+
+ return toolUses, currentToolUse
+}
+
+// DeduplicateToolUses removes duplicate tool uses based on toolUseId and content.
+func DeduplicateToolUses(toolUses []KiroToolUse) []KiroToolUse {
+ seenIDs := make(map[string]bool)
+ seenContent := make(map[string]bool)
+ var unique []KiroToolUse
+
+ for _, tu := range toolUses {
+ if seenIDs[tu.ToolUseID] {
+ log.Debugf("kiro: removing ID-duplicate tool use: %s (name: %s)", tu.ToolUseID, tu.Name)
+ continue
+ }
+
+ inputJSON, _ := json.Marshal(tu.Input)
+ contentKey := tu.Name + ":" + string(inputJSON)
+
+ if seenContent[contentKey] {
+ log.Debugf("kiro: removing content-duplicate tool use: %s (id: %s)", tu.Name, tu.ToolUseID)
+ continue
+ }
+
+ seenIDs[tu.ToolUseID] = true
+ seenContent[contentKey] = true
+ unique = append(unique, tu)
+ }
+
+ return unique
+}
diff --git a/internal/translator/kiro/claude/kiro_websearch.go b/internal/translator/kiro/claude/kiro_websearch.go
new file mode 100644
index 00000000..b9da3829
--- /dev/null
+++ b/internal/translator/kiro/claude/kiro_websearch.go
@@ -0,0 +1,495 @@
+// Package claude provides web search functionality for Kiro translator.
+// This file implements detection, MCP request/response types, and pure data
+// transformation utilities for web search. SSE event generation, stream analysis,
+// and HTTP I/O logic reside in the executor package (kiro_executor.go).
+package claude
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "github.com/google/uuid"
+ log "github.com/sirupsen/logrus"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+// cachedToolDescription stores the dynamically-fetched web_search tool description.
+// Written by the executor via SetWebSearchDescription, read by the translator
+// when building the remote_web_search tool for Kiro API requests.
+var cachedToolDescription atomic.Value // stores string
+
+// GetWebSearchDescription returns the cached web_search tool description,
+// or empty string if not yet fetched. Lock-free via atomic.Value.
+func GetWebSearchDescription() string {
+ if v := cachedToolDescription.Load(); v != nil {
+ return v.(string)
+ }
+ return ""
+}
+
+// SetWebSearchDescription stores the dynamically-fetched web_search tool description.
+// Called by the executor after fetching from MCP tools/list.
+func SetWebSearchDescription(desc string) {
+ cachedToolDescription.Store(desc)
+}
+
+// McpRequest represents a JSON-RPC 2.0 request to Kiro MCP API
+type McpRequest struct {
+ ID string `json:"id"`
+ JSONRPC string `json:"jsonrpc"`
+ Method string `json:"method"`
+ Params McpParams `json:"params"`
+}
+
+// McpParams represents MCP request parameters
+type McpParams struct {
+ Name string `json:"name"`
+ Arguments McpArguments `json:"arguments"`
+}
+
+// McpArgumentsMeta represents the _meta field in MCP arguments
+type McpArgumentsMeta struct {
+ IsValid bool `json:"_isValid"`
+ ActivePath []string `json:"_activePath"`
+ CompletedPaths [][]string `json:"_completedPaths"`
+}
+
+// McpArguments represents MCP request arguments
+type McpArguments struct {
+ Query string `json:"query"`
+ Meta *McpArgumentsMeta `json:"_meta,omitempty"`
+}
+
+// McpResponse represents a JSON-RPC 2.0 response from Kiro MCP API
+type McpResponse struct {
+ Error *McpError `json:"error,omitempty"`
+ ID string `json:"id"`
+ JSONRPC string `json:"jsonrpc"`
+ Result *McpResult `json:"result,omitempty"`
+}
+
+// McpError represents an MCP error
+type McpError struct {
+ Code *int `json:"code,omitempty"`
+ Message *string `json:"message,omitempty"`
+}
+
+// McpResult represents MCP result
+type McpResult struct {
+ Content []McpContent `json:"content"`
+ IsError bool `json:"isError"`
+}
+
+// McpContent represents MCP content item
+type McpContent struct {
+ ContentType string `json:"type"`
+ Text string `json:"text"`
+}
+
+// WebSearchResults represents parsed search results
+type WebSearchResults struct {
+ Results []WebSearchResult `json:"results"`
+ TotalResults *int `json:"totalResults,omitempty"`
+ Query *string `json:"query,omitempty"`
+ Error *string `json:"error,omitempty"`
+}
+
+// WebSearchResult represents a single search result
+type WebSearchResult struct {
+ Title string `json:"title"`
+ URL string `json:"url"`
+ Snippet *string `json:"snippet,omitempty"`
+ PublishedDate *int64 `json:"publishedDate,omitempty"`
+ ID *string `json:"id,omitempty"`
+ Domain *string `json:"domain,omitempty"`
+ MaxVerbatimWordLimit *int `json:"maxVerbatimWordLimit,omitempty"`
+ PublicDomain *bool `json:"publicDomain,omitempty"`
+}
+
+// isWebSearchTool checks if a tool name or type indicates a web_search tool.
+func isWebSearchTool(name, toolType string) bool {
+ return name == "web_search" ||
+ strings.HasPrefix(toolType, "web_search") ||
+ toolType == "web_search_20250305"
+}
+
+// HasWebSearchTool checks if the request contains ONLY a web_search tool.
+// Returns true only if tools array has exactly one tool named "web_search".
+// Only intercept pure web_search requests (single-tool array).
+func HasWebSearchTool(body []byte) bool {
+ tools := gjson.GetBytes(body, "tools")
+ if !tools.IsArray() {
+ return false
+ }
+
+ toolsArray := tools.Array()
+ if len(toolsArray) != 1 {
+ return false
+ }
+
+ // Check if the single tool is web_search
+ tool := toolsArray[0]
+
+ // Check both name and type fields for web_search detection
+ name := strings.ToLower(tool.Get("name").String())
+ toolType := strings.ToLower(tool.Get("type").String())
+
+ return isWebSearchTool(name, toolType)
+}
+
+// ExtractSearchQuery extracts the search query from the request.
+// Reads messages[0].content and removes "Perform a web search for the query: " prefix.
+func ExtractSearchQuery(body []byte) string {
+ messages := gjson.GetBytes(body, "messages")
+ if !messages.IsArray() || len(messages.Array()) == 0 {
+ return ""
+ }
+
+ firstMsg := messages.Array()[0]
+ content := firstMsg.Get("content")
+
+ var text string
+ if content.IsArray() {
+ // Array format: [{"type": "text", "text": "..."}]
+ for _, block := range content.Array() {
+ if block.Get("type").String() == "text" {
+ text = block.Get("text").String()
+ break
+ }
+ }
+ } else {
+ // String format
+ text = content.String()
+ }
+
+ // Remove prefix "Perform a web search for the query: "
+ const prefix = "Perform a web search for the query: "
+ if strings.HasPrefix(text, prefix) {
+ text = text[len(prefix):]
+ }
+
+ return strings.TrimSpace(text)
+}
+
+// generateRandomID8 generates an 8-character random lowercase alphanumeric string
+func generateRandomID8() string {
+ u := uuid.New()
+ return strings.ToLower(strings.ReplaceAll(u.String(), "-", "")[:8])
+}
+
+// CreateMcpRequest creates an MCP request for web search.
+// Returns (toolUseID, McpRequest)
+// ID format: web_search_tooluse_{22 random}_{timestamp_millis}_{8 random}
+func CreateMcpRequest(query string) (string, *McpRequest) {
+ random22 := GenerateToolUseID()
+ timestamp := time.Now().UnixMilli()
+ random8 := generateRandomID8()
+
+ requestID := fmt.Sprintf("web_search_tooluse_%s_%d_%s", random22, timestamp, random8)
+
+ // tool_use_id format: srvtoolu_{32 hex chars}
+ toolUseID := "srvtoolu_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:32]
+
+ request := &McpRequest{
+ ID: requestID,
+ JSONRPC: "2.0",
+ Method: "tools/call",
+ Params: McpParams{
+ Name: "web_search",
+ Arguments: McpArguments{
+ Query: query,
+ Meta: &McpArgumentsMeta{
+ IsValid: true,
+ ActivePath: []string{"query"},
+ CompletedPaths: [][]string{{"query"}},
+ },
+ },
+ },
+ }
+
+ return toolUseID, request
+}
+
+// GenerateToolUseID generates a Kiro-style tool use ID (base62-like UUID)
+func GenerateToolUseID() string {
+ return strings.ReplaceAll(uuid.New().String(), "-", "")[:22]
+}
+
+// ReplaceWebSearchToolDescription replaces the web_search tool description with
+// a minimal version that allows re-search without the restrictive "do not search
+// non-coding topics" instruction from the original Kiro tools/list response.
+// This keeps the tool available so the model can request additional searches.
+func ReplaceWebSearchToolDescription(body []byte) ([]byte, error) {
+ tools := gjson.GetBytes(body, "tools")
+ if !tools.IsArray() {
+ return body, nil
+ }
+
+ var updated []json.RawMessage
+ for _, tool := range tools.Array() {
+ name := strings.ToLower(tool.Get("name").String())
+ toolType := strings.ToLower(tool.Get("type").String())
+
+ if isWebSearchTool(name, toolType) {
+ // Replace with a minimal web_search tool definition
+ minimalTool := map[string]interface{}{
+ "name": "web_search",
+ "description": "Search the web for information. Use this when the previous search results are insufficient or when you need additional information on a different aspect of the query. Provide a refined or different search query.",
+ "input_schema": map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{
+ "query": map[string]interface{}{
+ "type": "string",
+ "description": "The search query to execute",
+ },
+ },
+ "required": []string{"query"},
+ "additionalProperties": false,
+ },
+ }
+ minimalJSON, err := json.Marshal(minimalTool)
+ if err != nil {
+ return body, fmt.Errorf("failed to marshal minimal tool: %w", err)
+ }
+ updated = append(updated, json.RawMessage(minimalJSON))
+ } else {
+ updated = append(updated, json.RawMessage(tool.Raw))
+ }
+ }
+
+ updatedJSON, err := json.Marshal(updated)
+ if err != nil {
+ return body, fmt.Errorf("failed to marshal updated tools: %w", err)
+ }
+ result, err := sjson.SetRawBytes(body, "tools", updatedJSON)
+ if err != nil {
+ return body, fmt.Errorf("failed to set updated tools: %w", err)
+ }
+
+ return result, nil
+}
+
+// FormatSearchContextPrompt formats search results as a structured text block
+// for injection into the system prompt.
+func FormatSearchContextPrompt(query string, results *WebSearchResults) string {
+ var sb strings.Builder
+ sb.WriteString(fmt.Sprintf("[Web Search Results for \"%s\"]\n", query))
+
+ if results != nil && len(results.Results) > 0 {
+ for i, r := range results.Results {
+ sb.WriteString(fmt.Sprintf("%d. %s - %s\n", i+1, r.Title, r.URL))
+ if r.Snippet != nil && *r.Snippet != "" {
+ snippet := *r.Snippet
+ if len(snippet) > 500 {
+ snippet = snippet[:500] + "..."
+ }
+ sb.WriteString(fmt.Sprintf(" %s\n", snippet))
+ }
+ }
+ } else {
+ sb.WriteString("No results found.\n")
+ }
+
+ sb.WriteString("[End Web Search Results]")
+ return sb.String()
+}
+
+// FormatToolResultText formats search results as JSON text for the toolResults content field.
+// This matches the format observed in Kiro IDE HAR captures.
+func FormatToolResultText(results *WebSearchResults) string {
+ if results == nil || len(results.Results) == 0 {
+ return "No search results found."
+ }
+
+ text := fmt.Sprintf("Found %d search result(s):\n\n", len(results.Results))
+ resultJSON, err := json.MarshalIndent(results.Results, "", " ")
+ if err != nil {
+ return text + "Error formatting results."
+ }
+ return text + string(resultJSON)
+}
+
+// InjectToolResultsClaude modifies a Claude-format JSON payload to append
+// tool_use (assistant) and tool_result (user) messages to the messages array.
+// BuildKiroPayload correctly translates:
+// - assistant tool_use → KiroAssistantResponseMessage.toolUses
+// - user tool_result → KiroUserInputMessageContext.toolResults
+//
+// This produces the exact same GAR request format as the Kiro IDE (HAR captures).
+// IMPORTANT: The web_search tool must remain in the "tools" array for this to work.
+// Use ReplaceWebSearchToolDescription to keep the tool available with a minimal description.
+func InjectToolResultsClaude(claudePayload []byte, toolUseId, query string, results *WebSearchResults) ([]byte, error) {
+ var payload map[string]interface{}
+ if err := json.Unmarshal(claudePayload, &payload); err != nil {
+ return claudePayload, fmt.Errorf("failed to parse claude payload: %w", err)
+ }
+
+ messages, _ := payload["messages"].([]interface{})
+
+ // 1. Append assistant message with tool_use (matches HAR: assistantResponseMessage.toolUses)
+ assistantMsg := map[string]interface{}{
+ "role": "assistant",
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_use",
+ "id": toolUseId,
+ "name": "web_search",
+ "input": map[string]interface{}{"query": query},
+ },
+ },
+ }
+ messages = append(messages, assistantMsg)
+
+ // 2. Append user message with tool_result + search behavior instructions.
+ // NOTE: We embed search instructions HERE (not in system prompt) because
+ // BuildKiroPayload clears the system prompt when len(history) > 0,
+ // which is always true after injecting assistant + user messages.
+ now := time.Now()
+ searchGuidance := fmt.Sprintf(`
+Current date: %s (%s)
+
+IMPORTANT: Evaluate the search results above carefully. If the results are:
+- Mostly spam, SEO junk, or unrelated websites
+- Missing actual information about the query topic
+- Outdated or not matching the requested time frame
+
+Then you MUST use the web_search tool again with a refined query. Try:
+- Rephrasing in English for better coverage
+- Using more specific keywords
+- Adding date context
+
+Do NOT apologize for bad results without first attempting a re-search.
+`, now.Format("January 2, 2006"), now.Format("Monday"))
+
+ userMsg := map[string]interface{}{
+ "role": "user",
+ "content": []interface{}{
+ map[string]interface{}{
+ "type": "tool_result",
+ "tool_use_id": toolUseId,
+ "content": FormatToolResultText(results),
+ },
+ map[string]interface{}{
+ "type": "text",
+ "text": searchGuidance,
+ },
+ },
+ }
+ messages = append(messages, userMsg)
+
+ payload["messages"] = messages
+
+ result, err := json.Marshal(payload)
+ if err != nil {
+ return claudePayload, fmt.Errorf("failed to marshal updated payload: %w", err)
+ }
+
+ log.Infof("kiro/websearch: injected tool_use+tool_result (toolUseId=%s, messages=%d)",
+ toolUseId, len(messages))
+
+ return result, nil
+}
+
+// InjectSearchIndicatorsInResponse prepends server_tool_use + web_search_tool_result
+// content blocks into a non-streaming Claude JSON response. Claude Code counts
+// server_tool_use blocks to display "Did X searches in Ys".
+//
+// Input response: {"content": [{"type":"text","text":"..."}], ...}
+// Output response: {"content": [{"type":"server_tool_use",...}, {"type":"web_search_tool_result",...}, {"type":"text","text":"..."}], ...}
+func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchIndicator) ([]byte, error) {
+ if len(searches) == 0 {
+ return responsePayload, nil
+ }
+
+ var resp map[string]interface{}
+ if err := json.Unmarshal(responsePayload, &resp); err != nil {
+ return responsePayload, fmt.Errorf("failed to parse response: %w", err)
+ }
+
+ existingContent, _ := resp["content"].([]interface{})
+
+ // Build new content: search indicators first, then existing content
+ newContent := make([]interface{}, 0, len(searches)*2+len(existingContent))
+
+ for _, s := range searches {
+ // server_tool_use block
+ newContent = append(newContent, map[string]interface{}{
+ "type": "server_tool_use",
+ "id": s.ToolUseID,
+ "name": "web_search",
+ "input": map[string]interface{}{"query": s.Query},
+ })
+
+ // web_search_tool_result block
+ searchContent := make([]map[string]interface{}, 0)
+ if s.Results != nil {
+ for _, r := range s.Results.Results {
+ snippet := ""
+ if r.Snippet != nil {
+ snippet = *r.Snippet
+ }
+ searchContent = append(searchContent, map[string]interface{}{
+ "type": "web_search_result",
+ "title": r.Title,
+ "url": r.URL,
+ "encrypted_content": snippet,
+ "page_age": nil,
+ })
+ }
+ }
+ newContent = append(newContent, map[string]interface{}{
+ "type": "web_search_tool_result",
+ "tool_use_id": s.ToolUseID,
+ "content": searchContent,
+ })
+ }
+
+ // Append existing content blocks
+ newContent = append(newContent, existingContent...)
+ resp["content"] = newContent
+
+ result, err := json.Marshal(resp)
+ if err != nil {
+ return responsePayload, fmt.Errorf("failed to marshal response: %w", err)
+ }
+
+ log.Infof("kiro/websearch: injected %d search indicator(s) into non-stream response", len(searches))
+ return result, nil
+}
+
+// SearchIndicator holds the data for one search operation to inject into a response.
+type SearchIndicator struct {
+ ToolUseID string
+ Query string
+ Results *WebSearchResults
+}
+
+// BuildMcpEndpoint constructs the MCP endpoint URL for the given AWS region.
+// Centralizes the URL pattern used by both handleWebSearch and handleWebSearchStream.
+func BuildMcpEndpoint(region string) string {
+ return fmt.Sprintf("https://q.%s.amazonaws.com/mcp", region)
+}
+
+// ParseSearchResults extracts WebSearchResults from MCP response
+func ParseSearchResults(response *McpResponse) *WebSearchResults {
+ if response == nil || response.Result == nil || len(response.Result.Content) == 0 {
+ return nil
+ }
+
+ content := response.Result.Content[0]
+ if content.ContentType != "text" {
+ return nil
+ }
+
+ var results WebSearchResults
+ if err := json.Unmarshal([]byte(content.Text), &results); err != nil {
+ log.Warnf("kiro/websearch: failed to parse search results: %v", err)
+ return nil
+ }
+
+ return &results
+}
diff --git a/internal/translator/kiro/claude/tool_compression.go b/internal/translator/kiro/claude/tool_compression.go
new file mode 100644
index 00000000..7d4a424e
--- /dev/null
+++ b/internal/translator/kiro/claude/tool_compression.go
@@ -0,0 +1,191 @@
+// Package claude provides tool compression functionality for Kiro translator.
+// This file implements dynamic tool compression to reduce tool payload size
+// when it exceeds the target threshold, preventing 500 errors from Kiro API.
+package claude
+
+import (
+ "encoding/json"
+ "unicode/utf8"
+
+ kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
+ log "github.com/sirupsen/logrus"
+)
+
+// calculateToolsSize calculates the JSON serialized size of the tools list.
+// Returns the size in bytes.
+func calculateToolsSize(tools []KiroToolWrapper) int {
+ if len(tools) == 0 {
+ return 0
+ }
+ data, err := json.Marshal(tools)
+ if err != nil {
+ log.Warnf("kiro: failed to marshal tools for size calculation: %v", err)
+ return 0
+ }
+ return len(data)
+}
+
+// simplifyInputSchema simplifies the input_schema by keeping only essential fields:
+// type, enum, required. Recursively processes nested properties.
+func simplifyInputSchema(schema interface{}) interface{} {
+ if schema == nil {
+ return nil
+ }
+
+ schemaMap, ok := schema.(map[string]interface{})
+ if !ok {
+ return schema
+ }
+
+ simplified := make(map[string]interface{})
+
+ // Keep essential fields
+ if t, ok := schemaMap["type"]; ok {
+ simplified["type"] = t
+ }
+ if enum, ok := schemaMap["enum"]; ok {
+ simplified["enum"] = enum
+ }
+ if required, ok := schemaMap["required"]; ok {
+ simplified["required"] = required
+ }
+
+ // Recursively process properties
+ if properties, ok := schemaMap["properties"].(map[string]interface{}); ok {
+ simplifiedProps := make(map[string]interface{})
+ for key, value := range properties {
+ simplifiedProps[key] = simplifyInputSchema(value)
+ }
+ simplified["properties"] = simplifiedProps
+ }
+
+ // Process items for array types
+ if items, ok := schemaMap["items"]; ok {
+ simplified["items"] = simplifyInputSchema(items)
+ }
+
+ // Process additionalProperties if present
+ if additionalProps, ok := schemaMap["additionalProperties"]; ok {
+ simplified["additionalProperties"] = simplifyInputSchema(additionalProps)
+ }
+
+ // Process anyOf, oneOf, allOf
+ for _, key := range []string{"anyOf", "oneOf", "allOf"} {
+ if arr, ok := schemaMap[key].([]interface{}); ok {
+ simplifiedArr := make([]interface{}, len(arr))
+ for i, item := range arr {
+ simplifiedArr[i] = simplifyInputSchema(item)
+ }
+ simplified[key] = simplifiedArr
+ }
+ }
+
+ return simplified
+}
+
+// compressToolDescription compresses a description to the target length.
+// Ensures the result is at least MinToolDescriptionLength characters.
+// Uses UTF-8 safe truncation.
+func compressToolDescription(description string, targetLength int) string {
+ if targetLength < kirocommon.MinToolDescriptionLength {
+ targetLength = kirocommon.MinToolDescriptionLength
+ }
+
+ if len(description) <= targetLength {
+ return description
+ }
+
+ // Find a safe truncation point (UTF-8 boundary)
+ truncLen := targetLength - 3 // Leave room for "..."
+
+ // Ensure we don't cut in the middle of a UTF-8 character
+ for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
+ truncLen--
+ }
+
+ if truncLen <= 0 {
+ return description[:kirocommon.MinToolDescriptionLength]
+ }
+
+ return description[:truncLen] + "..."
+}
+
+// compressToolsIfNeeded compresses tools if their total size exceeds the target threshold.
+// Compression strategy:
+// 1. First, check if compression is needed (size > ToolCompressionTargetSize)
+// 2. Step 1: Simplify input_schema (keep only type/enum/required)
+// 3. Step 2: Proportionally compress descriptions (minimum MinToolDescriptionLength chars)
+// Returns the compressed tools list.
+func compressToolsIfNeeded(tools []KiroToolWrapper) []KiroToolWrapper {
+ if len(tools) == 0 {
+ return tools
+ }
+
+ originalSize := calculateToolsSize(tools)
+ if originalSize <= kirocommon.ToolCompressionTargetSize {
+ log.Debugf("kiro: tools size %d bytes is within target %d bytes, no compression needed",
+ originalSize, kirocommon.ToolCompressionTargetSize)
+ return tools
+ }
+
+ log.Infof("kiro: tools size %d bytes exceeds target %d bytes, starting compression",
+ originalSize, kirocommon.ToolCompressionTargetSize)
+
+ // Create a copy of tools to avoid modifying the original
+ compressedTools := make([]KiroToolWrapper, len(tools))
+ for i, tool := range tools {
+ compressedTools[i] = KiroToolWrapper{
+ ToolSpecification: KiroToolSpecification{
+ Name: tool.ToolSpecification.Name,
+ Description: tool.ToolSpecification.Description,
+ InputSchema: KiroInputSchema{JSON: tool.ToolSpecification.InputSchema.JSON},
+ },
+ }
+ }
+
+ // Step 1: Simplify input_schema
+ for i := range compressedTools {
+ compressedTools[i].ToolSpecification.InputSchema.JSON =
+ simplifyInputSchema(compressedTools[i].ToolSpecification.InputSchema.JSON)
+ }
+
+ sizeAfterSchemaSimplification := calculateToolsSize(compressedTools)
+ log.Debugf("kiro: size after schema simplification: %d bytes (reduced by %d bytes)",
+ sizeAfterSchemaSimplification, originalSize-sizeAfterSchemaSimplification)
+
+ // Check if we're within target after schema simplification
+ if sizeAfterSchemaSimplification <= kirocommon.ToolCompressionTargetSize {
+ log.Infof("kiro: compression complete after schema simplification, final size: %d bytes",
+ sizeAfterSchemaSimplification)
+ return compressedTools
+ }
+
+ // Step 2: Compress descriptions proportionally
+ sizeToReduce := float64(sizeAfterSchemaSimplification - kirocommon.ToolCompressionTargetSize)
+ var totalDescLen float64
+ for _, tool := range compressedTools {
+ totalDescLen += float64(len(tool.ToolSpecification.Description))
+ }
+
+ if totalDescLen > 0 {
+ // Assume size reduction comes primarily from descriptions.
+ keepRatio := 1.0 - (sizeToReduce / totalDescLen)
+ if keepRatio > 1.0 {
+ keepRatio = 1.0
+ } else if keepRatio < 0 {
+ keepRatio = 0
+ }
+
+ for i := range compressedTools {
+ desc := compressedTools[i].ToolSpecification.Description
+ targetLen := int(float64(len(desc)) * keepRatio)
+ compressedTools[i].ToolSpecification.Description = compressToolDescription(desc, targetLen)
+ }
+ }
+
+ finalSize := calculateToolsSize(compressedTools)
+ log.Infof("kiro: compression complete, original: %d bytes, final: %d bytes (%.1f%% reduction)",
+ originalSize, finalSize, float64(originalSize-finalSize)/float64(originalSize)*100)
+
+ return compressedTools
+}
diff --git a/internal/translator/kiro/claude/truncation_detector.go b/internal/translator/kiro/claude/truncation_detector.go
new file mode 100644
index 00000000..b05ec11a
--- /dev/null
+++ b/internal/translator/kiro/claude/truncation_detector.go
@@ -0,0 +1,517 @@
+// Package claude provides truncation detection for Kiro tool call responses.
+// When Kiro API reaches its output token limit, tool call JSON may be truncated,
+// resulting in incomplete or unparseable input parameters.
+package claude
+
+import (
+ "encoding/json"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// TruncationInfo contains details about detected truncation in a tool use event.
+type TruncationInfo struct {
+ IsTruncated bool // Whether truncation was detected
+ TruncationType string // Type of truncation detected
+ ToolName string // Name of the truncated tool
+ ToolUseID string // ID of the truncated tool use
+ RawInput string // The raw (possibly truncated) input string
+ ParsedFields map[string]string // Fields that were successfully parsed before truncation
+ ErrorMessage string // Human-readable error message
+}
+
+// TruncationType constants for different truncation scenarios
+const (
+ TruncationTypeNone = "" // No truncation detected
+ TruncationTypeEmptyInput = "empty_input" // No input data received at all
+ TruncationTypeInvalidJSON = "invalid_json" // JSON is syntactically invalid (truncated mid-value)
+ TruncationTypeMissingFields = "missing_fields" // JSON parsed but critical fields are missing
+ TruncationTypeIncompleteString = "incomplete_string" // String value was cut off mid-content
+)
+
+// KnownWriteTools lists tool names that typically write content and have a "content" field.
+// These tools are checked for content field truncation specifically.
+var KnownWriteTools = map[string]bool{
+ "Write": true,
+ "write_to_file": true,
+ "fsWrite": true,
+ "create_file": true,
+ "edit_file": true,
+ "apply_diff": true,
+ "str_replace_editor": true,
+ "insert": true,
+}
+
+// KnownCommandTools lists tool names that execute commands.
+var KnownCommandTools = map[string]bool{
+ "Bash": true,
+ "execute": true,
+ "run_command": true,
+ "shell": true,
+ "terminal": true,
+ "execute_python": true,
+}
+
+// RequiredFieldsByTool maps tool names to their required fields.
+// If any of these fields are missing, the tool input is considered truncated.
+var RequiredFieldsByTool = map[string][]string{
+ "Write": {"file_path", "content"},
+ "write_to_file": {"path", "content"},
+ "fsWrite": {"path", "content"},
+ "create_file": {"path", "content"},
+ "edit_file": {"path"},
+ "apply_diff": {"path", "diff"},
+ "str_replace_editor": {"path", "old_str", "new_str"},
+ "Bash": {"command"},
+ "execute": {"command"},
+ "run_command": {"command"},
+}
+
+// DetectTruncation checks if the tool use input appears to be truncated.
+// It returns detailed information about the truncation status and type.
+func DetectTruncation(toolName, toolUseID, rawInput string, parsedInput map[string]interface{}) TruncationInfo {
+ info := TruncationInfo{
+ ToolName: toolName,
+ ToolUseID: toolUseID,
+ RawInput: rawInput,
+ ParsedFields: make(map[string]string),
+ }
+
+ // Scenario 1: Empty input buffer - no data received at all
+ if strings.TrimSpace(rawInput) == "" {
+ info.IsTruncated = true
+ info.TruncationType = TruncationTypeEmptyInput
+ info.ErrorMessage = "Tool input was completely empty - API response may have been truncated before tool parameters were transmitted"
+ log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): empty input buffer",
+ info.TruncationType, toolName, toolUseID)
+ return info
+ }
+
+ // Scenario 2: JSON parse failure - syntactically invalid JSON
+ if parsedInput == nil || len(parsedInput) == 0 {
+ // Check if the raw input looks like truncated JSON
+ if looksLikeTruncatedJSON(rawInput) {
+ info.IsTruncated = true
+ info.TruncationType = TruncationTypeInvalidJSON
+ info.ParsedFields = extractPartialFields(rawInput)
+ info.ErrorMessage = buildTruncationErrorMessage(toolName, info.TruncationType, info.ParsedFields, rawInput)
+ log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): JSON parse failed, raw length=%d bytes",
+ info.TruncationType, toolName, toolUseID, len(rawInput))
+ return info
+ }
+ }
+
+ // Scenario 3: JSON parsed but critical fields are missing
+ if parsedInput != nil {
+ requiredFields, hasRequirements := RequiredFieldsByTool[toolName]
+ if hasRequirements {
+ missingFields := findMissingRequiredFields(parsedInput, requiredFields)
+ if len(missingFields) > 0 {
+ info.IsTruncated = true
+ info.TruncationType = TruncationTypeMissingFields
+ info.ParsedFields = extractParsedFieldNames(parsedInput)
+ info.ErrorMessage = buildMissingFieldsErrorMessage(toolName, missingFields, info.ParsedFields)
+ log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): missing required fields: %v",
+ info.TruncationType, toolName, toolUseID, missingFields)
+ return info
+ }
+ }
+
+ // Scenario 4: Check for incomplete string values (very short content for write tools)
+ if isWriteTool(toolName) {
+ if contentTruncation := detectContentTruncation(parsedInput, rawInput); contentTruncation != "" {
+ info.IsTruncated = true
+ info.TruncationType = TruncationTypeIncompleteString
+ info.ParsedFields = extractParsedFieldNames(parsedInput)
+ info.ErrorMessage = contentTruncation
+ log.Warnf("kiro: truncation detected [%s] for tool %s (ID: %s): %s",
+ info.TruncationType, toolName, toolUseID, contentTruncation)
+ return info
+ }
+ }
+ }
+
+ // No truncation detected
+ info.IsTruncated = false
+ info.TruncationType = TruncationTypeNone
+ return info
+}
+
+// looksLikeTruncatedJSON checks if the raw string appears to be truncated JSON.
+func looksLikeTruncatedJSON(raw string) bool {
+ trimmed := strings.TrimSpace(raw)
+ if trimmed == "" {
+ return false
+ }
+
+ // Must start with { to be considered JSON
+ if !strings.HasPrefix(trimmed, "{") {
+ return false
+ }
+
+ // Count brackets to detect imbalance
+ openBraces := strings.Count(trimmed, "{")
+ closeBraces := strings.Count(trimmed, "}")
+ openBrackets := strings.Count(trimmed, "[")
+ closeBrackets := strings.Count(trimmed, "]")
+
+ // Bracket imbalance suggests truncation
+ if openBraces > closeBraces || openBrackets > closeBrackets {
+ return true
+ }
+
+ // Check for obvious truncation patterns
+ // - Ends with a quote but no closing brace
+ // - Ends with a colon (mid key-value)
+ // - Ends with a comma (mid object/array)
+ lastChar := trimmed[len(trimmed)-1]
+ if lastChar != '}' && lastChar != ']' {
+ // Check if it's not a complete simple value
+ if lastChar == '"' || lastChar == ':' || lastChar == ',' {
+ return true
+ }
+ }
+
+ // Check for unclosed strings (odd number of unescaped quotes)
+ inString := false
+ escaped := false
+ for i := 0; i < len(trimmed); i++ {
+ c := trimmed[i]
+ if escaped {
+ escaped = false
+ continue
+ }
+ if c == '\\' {
+ escaped = true
+ continue
+ }
+ if c == '"' {
+ inString = !inString
+ }
+ }
+ if inString {
+ return true // Unclosed string
+ }
+
+ return false
+}
+
+// extractPartialFields attempts to extract any field names from malformed JSON.
+// This helps provide context about what was received before truncation.
+func extractPartialFields(raw string) map[string]string {
+ fields := make(map[string]string)
+
+ // Simple pattern matching for "key": "value" or "key": value patterns
+ // This works even with truncated JSON
+ trimmed := strings.TrimSpace(raw)
+ if !strings.HasPrefix(trimmed, "{") {
+ return fields
+ }
+
+ // Remove opening brace
+ content := strings.TrimPrefix(trimmed, "{")
+
+ // Split by comma (rough parsing)
+ parts := strings.Split(content, ",")
+ for _, part := range parts {
+ part = strings.TrimSpace(part)
+ if colonIdx := strings.Index(part, ":"); colonIdx > 0 {
+ key := strings.TrimSpace(part[:colonIdx])
+ key = strings.Trim(key, `"`)
+ value := strings.TrimSpace(part[colonIdx+1:])
+
+ // Truncate long values for display
+ if len(value) > 50 {
+ value = value[:50] + "..."
+ }
+ fields[key] = value
+ }
+ }
+
+ return fields
+}
+
+// extractParsedFieldNames returns the field names from a successfully parsed map.
+func extractParsedFieldNames(parsed map[string]interface{}) map[string]string {
+ fields := make(map[string]string)
+ for key, val := range parsed {
+ switch v := val.(type) {
+ case string:
+ if len(v) > 50 {
+ fields[key] = v[:50] + "..."
+ } else {
+ fields[key] = v
+ }
+ case nil:
+ fields[key] = ""
+ default:
+ // For complex types, just indicate presence
+ fields[key] = ""
+ }
+ }
+ return fields
+}
+
+// findMissingRequiredFields checks which required fields are missing from the parsed input.
+func findMissingRequiredFields(parsed map[string]interface{}, required []string) []string {
+ var missing []string
+ for _, field := range required {
+ if _, exists := parsed[field]; !exists {
+ missing = append(missing, field)
+ }
+ }
+ return missing
+}
+
+// isWriteTool checks if the tool is a known write/file operation tool.
+func isWriteTool(toolName string) bool {
+ return KnownWriteTools[toolName]
+}
+
+// detectContentTruncation checks if the content field appears truncated for write tools.
+func detectContentTruncation(parsed map[string]interface{}, rawInput string) string {
+ // Check for content field
+ content, hasContent := parsed["content"]
+ if !hasContent {
+ return ""
+ }
+
+ contentStr, isString := content.(string)
+ if !isString {
+ return ""
+ }
+
+ // Heuristic: if raw input is very large but content is suspiciously short,
+ // it might indicate truncation during JSON repair
+ if len(rawInput) > 1000 && len(contentStr) < 100 {
+ return "content field appears suspiciously short compared to raw input size"
+ }
+
+ // Check for code blocks that appear to be cut off
+ if strings.Contains(contentStr, "```") {
+ openFences := strings.Count(contentStr, "```")
+ if openFences%2 != 0 {
+ return "content contains unclosed code fence (```) suggesting truncation"
+ }
+ }
+
+ return ""
+}
+
+// buildTruncationErrorMessage creates a human-readable error message for truncation.
+func buildTruncationErrorMessage(toolName, truncationType string, parsedFields map[string]string, rawInput string) string {
+ var sb strings.Builder
+ sb.WriteString("Tool input was truncated by the API. ")
+
+ switch truncationType {
+ case TruncationTypeEmptyInput:
+ sb.WriteString("No input data was received.")
+ case TruncationTypeInvalidJSON:
+ sb.WriteString("JSON was cut off mid-transmission. ")
+ if len(parsedFields) > 0 {
+ sb.WriteString("Partial fields received: ")
+ first := true
+ for k := range parsedFields {
+ if !first {
+ sb.WriteString(", ")
+ }
+ sb.WriteString(k)
+ first = false
+ }
+ }
+ case TruncationTypeMissingFields:
+ sb.WriteString("Required fields are missing from the input.")
+ case TruncationTypeIncompleteString:
+ sb.WriteString("Content appears to be shortened or incomplete.")
+ }
+
+ sb.WriteString(" Received ")
+ sb.WriteString(string(rune(len(rawInput))))
+ sb.WriteString(" bytes. Please retry with smaller content chunks.")
+
+ return sb.String()
+}
+
+// buildMissingFieldsErrorMessage creates an error message for missing required fields.
+func buildMissingFieldsErrorMessage(toolName string, missingFields []string, parsedFields map[string]string) string {
+ var sb strings.Builder
+ sb.WriteString("Tool '")
+ sb.WriteString(toolName)
+ sb.WriteString("' is missing required fields: ")
+ sb.WriteString(strings.Join(missingFields, ", "))
+ sb.WriteString(". Fields received: ")
+
+ first := true
+ for k := range parsedFields {
+ if !first {
+ sb.WriteString(", ")
+ }
+ sb.WriteString(k)
+ first = false
+ }
+
+ sb.WriteString(". This usually indicates the API response was truncated.")
+ return sb.String()
+}
+
+// IsTruncated is a convenience function to check if a tool use appears truncated.
+func IsTruncated(toolName, rawInput string, parsedInput map[string]interface{}) bool {
+ info := DetectTruncation(toolName, "", rawInput, parsedInput)
+ return info.IsTruncated
+}
+
+// GetTruncationSummary returns a short summary string for logging.
+func GetTruncationSummary(info TruncationInfo) string {
+ if !info.IsTruncated {
+ return ""
+ }
+
+ result, _ := json.Marshal(map[string]interface{}{
+ "tool": info.ToolName,
+ "type": info.TruncationType,
+ "parsed_fields": info.ParsedFields,
+ "raw_input_size": len(info.RawInput),
+ })
+ return string(result)
+}
+
+// SoftFailureMessage contains the message structure for a truncation soft failure.
+// This is returned to Claude as a tool_result to guide retry behavior.
+type SoftFailureMessage struct {
+ Status string // "incomplete" - not an error, just incomplete
+ Reason string // Why the tool call was incomplete
+ Guidance []string // Step-by-step retry instructions
+ Context string // Any context about what was received
+ MaxLineHint int // Suggested maximum lines per chunk
+}
+
+// BuildSoftFailureMessage creates a structured message for Claude when truncation is detected.
+// This follows the "soft failure" pattern:
+// - For Claude: Clear explanation of what happened and how to fix
+// - For User: Hidden or minimized (appears as normal processing)
+//
+// Key principle: "Conclusion First"
+// 1. First state what happened (incomplete)
+// 2. Then explain how to fix (chunked approach)
+// 3. Provide specific guidance (line limits)
+func BuildSoftFailureMessage(info TruncationInfo) SoftFailureMessage {
+ msg := SoftFailureMessage{
+ Status: "incomplete",
+ MaxLineHint: 300, // Conservative default
+ }
+
+ // Build reason based on truncation type
+ switch info.TruncationType {
+ case TruncationTypeEmptyInput:
+ msg.Reason = "Your tool call was too large and the input was completely lost during transmission."
+ msg.MaxLineHint = 200
+ case TruncationTypeInvalidJSON:
+ msg.Reason = "Your tool call was truncated mid-transmission, resulting in incomplete JSON."
+ msg.MaxLineHint = 250
+ case TruncationTypeMissingFields:
+ msg.Reason = "Your tool call was partially received but critical fields were cut off."
+ msg.MaxLineHint = 300
+ case TruncationTypeIncompleteString:
+ msg.Reason = "Your tool call content was truncated - the full content did not arrive."
+ msg.MaxLineHint = 350
+ default:
+ msg.Reason = "Your tool call was truncated by the API due to output size limits."
+ }
+
+ // Build context from parsed fields
+ if len(info.ParsedFields) > 0 {
+ var parts []string
+ for k, v := range info.ParsedFields {
+ if len(v) > 30 {
+ v = v[:30] + "..."
+ }
+ parts = append(parts, k+"="+v)
+ }
+ msg.Context = "Received partial data: " + strings.Join(parts, ", ")
+ }
+
+ // Build retry guidance - CRITICAL: Conclusion first approach
+ msg.Guidance = []string{
+ "CONCLUSION: Split your output into smaller chunks and retry.",
+ "",
+ "REQUIRED APPROACH:",
+ "1. For file writes: Write in chunks of ~" + formatInt(msg.MaxLineHint) + " lines maximum",
+ "2. For new files: First create with initial chunk, then append remaining sections",
+ "3. For edits: Make surgical, targeted changes - avoid rewriting entire files",
+ "",
+ "EXAMPLE (writing a 600-line file):",
+ " - Step 1: Write lines 1-300 (create file)",
+ " - Step 2: Append lines 301-600 (extend file)",
+ "",
+ "DO NOT attempt to write the full content again in a single call.",
+ "The API has a hard output limit that cannot be bypassed.",
+ }
+
+ return msg
+}
+
+// formatInt converts an integer to string (helper to avoid strconv import)
+func formatInt(n int) string {
+ if n == 0 {
+ return "0"
+ }
+ result := ""
+ for n > 0 {
+ result = string(rune('0'+n%10)) + result
+ n /= 10
+ }
+ return result
+}
+
+// BuildSoftFailureToolResult creates a tool_result content for Claude.
+// This is what Claude will see when a tool call is truncated.
+// Returns a string that should be used as the tool_result content.
+func BuildSoftFailureToolResult(info TruncationInfo) string {
+ msg := BuildSoftFailureMessage(info)
+
+ var sb strings.Builder
+ sb.WriteString("TOOL_CALL_INCOMPLETE\n")
+ sb.WriteString("status: ")
+ sb.WriteString(msg.Status)
+ sb.WriteString("\n")
+ sb.WriteString("reason: ")
+ sb.WriteString(msg.Reason)
+ sb.WriteString("\n")
+
+ if msg.Context != "" {
+ sb.WriteString("context: ")
+ sb.WriteString(msg.Context)
+ sb.WriteString("\n")
+ }
+
+ sb.WriteString("\n")
+ for _, line := range msg.Guidance {
+ if line != "" {
+ sb.WriteString(line)
+ sb.WriteString("\n")
+ }
+ }
+
+ return sb.String()
+}
+
+// CreateTruncationToolResult creates a KiroToolUse that represents a soft failure.
+// Instead of returning the truncated tool_use, we return a tool with a special
+// error result that guides Claude to retry with smaller chunks.
+//
+// This is the key mechanism for "soft failure":
+// - stop_reason remains "tool_use" so Claude continues
+// - The tool_result content explains the issue and how to fix it
+// - Claude will read this and adjust its approach
+func CreateTruncationToolResult(info TruncationInfo) KiroToolUse {
+ // We create a pseudo tool_use that represents the failed attempt
+ // The executor will convert this to a tool_result with the guidance message
+ return KiroToolUse{
+ ToolUseID: info.ToolUseID,
+ Name: info.ToolName,
+ Input: nil, // No input since it was truncated
+ IsTruncated: true,
+ TruncationInfo: &info,
+ }
+}
diff --git a/internal/translator/kiro/common/constants.go b/internal/translator/kiro/common/constants.go
new file mode 100644
index 00000000..3016947c
--- /dev/null
+++ b/internal/translator/kiro/common/constants.go
@@ -0,0 +1,103 @@
+// Package common provides shared constants and utilities for Kiro translator.
+package common
+
+const (
+ // KiroMaxToolDescLen is the maximum description length for Kiro API tools.
+ // Kiro API limit is 10240 bytes, leave room for "..."
+ KiroMaxToolDescLen = 10237
+
+ // ToolCompressionTargetSize is the target total size for compressed tools (20KB).
+ // If tools exceed this size, compression will be applied.
+ ToolCompressionTargetSize = 20 * 1024 // 20KB
+
+ // MinToolDescriptionLength is the minimum description length after compression.
+ // Descriptions will not be shortened below this length.
+ MinToolDescriptionLength = 50
+
+ // ThinkingStartTag is the start tag for thinking blocks in responses.
+ ThinkingStartTag = ""
+
+ // ThinkingEndTag is the end tag for thinking blocks in responses.
+ ThinkingEndTag = ""
+
+ // CodeFenceMarker is the markdown code fence marker.
+ CodeFenceMarker = "```"
+
+ // AltCodeFenceMarker is the alternative markdown code fence marker.
+ AltCodeFenceMarker = "~~~"
+
+ // InlineCodeMarker is the markdown inline code marker (backtick).
+ InlineCodeMarker = "`"
+
+ // DefaultAssistantContentWithTools is the fallback content for assistant messages
+ // that have tool_use but no text content. Kiro API requires non-empty content.
+ // IMPORTANT: Use a minimal neutral string that the model won't mimic in responses.
+ // Previously "I'll help you with that." which caused the model to parrot it back.
+ DefaultAssistantContentWithTools = "."
+
+ // DefaultAssistantContent is the fallback content for assistant messages
+ // that have no content at all. Kiro API requires non-empty content.
+ // IMPORTANT: Use a minimal neutral string that the model won't mimic in responses.
+ // Previously "I understand." which could leak into model behavior.
+ DefaultAssistantContent = "."
+
+ // DefaultUserContentWithToolResults is the fallback content for user messages
+ // that have only tool_result (no text). Kiro API requires non-empty content.
+ DefaultUserContentWithToolResults = "Tool results provided."
+
+ // DefaultUserContent is the fallback content for user messages
+ // that have no content at all. Kiro API requires non-empty content.
+ DefaultUserContent = "Continue"
+
+ // KiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes.
+ // AWS Kiro API has a 2-3 minute timeout for large file write operations.
+ KiroAgenticSystemPrompt = `
+# CRITICAL: CHUNKED WRITE PROTOCOL (MANDATORY)
+
+You MUST follow these rules for ALL file operations. Violation causes server timeouts and task failure.
+
+## ABSOLUTE LIMITS
+- **MAXIMUM 350 LINES** per single write/edit operation - NO EXCEPTIONS
+- **RECOMMENDED 300 LINES** or less for optimal performance
+- **NEVER** write entire files in one operation if >300 lines
+
+## MANDATORY CHUNKED WRITE STRATEGY
+
+### For NEW FILES (>300 lines total):
+1. FIRST: Write initial chunk (first 250-300 lines) using write_to_file/fsWrite
+2. THEN: Append remaining content in 250-300 line chunks using file append operations
+3. REPEAT: Continue appending until complete
+
+### For EDITING EXISTING FILES:
+1. Use surgical edits (apply_diff/targeted edits) - change ONLY what's needed
+2. NEVER rewrite entire files - use incremental modifications
+3. Split large refactors into multiple small, focused edits
+
+### For LARGE CODE GENERATION:
+1. Generate in logical sections (imports, types, functions separately)
+2. Write each section as a separate operation
+3. Use append operations for subsequent sections
+
+## EXAMPLES OF CORRECT BEHAVIOR
+
+✅ CORRECT: Writing a 600-line file
+- Operation 1: Write lines 1-300 (initial file creation)
+- Operation 2: Append lines 301-600
+
+✅ CORRECT: Editing multiple functions
+- Operation 1: Edit function A
+- Operation 2: Edit function B
+- Operation 3: Edit function C
+
+❌ WRONG: Writing 500 lines in single operation → TIMEOUT
+❌ WRONG: Rewriting entire file to change 5 lines → TIMEOUT
+❌ WRONG: Generating massive code blocks without chunking → TIMEOUT
+
+## WHY THIS MATTERS
+- Server has 2-3 minute timeout for operations
+- Large writes exceed timeout and FAIL completely
+- Chunked writes are FASTER and more RELIABLE
+- Failed writes waste time and require retry
+
+REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.`
+)
diff --git a/internal/translator/kiro/common/message_merge.go b/internal/translator/kiro/common/message_merge.go
new file mode 100644
index 00000000..2765fc6e
--- /dev/null
+++ b/internal/translator/kiro/common/message_merge.go
@@ -0,0 +1,160 @@
+// Package common provides shared utilities for Kiro translators.
+package common
+
+import (
+ "encoding/json"
+
+ "github.com/tidwall/gjson"
+)
+
+// MergeAdjacentMessages merges adjacent messages with the same role.
+// This reduces API call complexity and improves compatibility.
+// Based on AIClient-2-API implementation.
+// NOTE: Tool messages are NOT merged because each has a unique tool_call_id that must be preserved.
+func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
+ if len(messages) <= 1 {
+ return messages
+ }
+
+ var merged []gjson.Result
+ for _, msg := range messages {
+ if len(merged) == 0 {
+ merged = append(merged, msg)
+ continue
+ }
+
+ lastMsg := merged[len(merged)-1]
+ currentRole := msg.Get("role").String()
+ lastRole := lastMsg.Get("role").String()
+
+ // Don't merge tool messages - each has a unique tool_call_id
+ if currentRole == "tool" || lastRole == "tool" {
+ merged = append(merged, msg)
+ continue
+ }
+
+ if currentRole == lastRole {
+ // Merge content from current message into last message
+ mergedContent := mergeMessageContent(lastMsg, msg)
+ var mergedToolCalls []interface{}
+ if currentRole == "assistant" {
+ // Preserve assistant tool_calls when adjacent assistant messages are merged.
+ mergedToolCalls = mergeToolCalls(lastMsg.Get("tool_calls"), msg.Get("tool_calls"))
+ }
+
+ // Create a new merged message JSON.
+ mergedMsg := createMergedMessage(lastRole, mergedContent, mergedToolCalls)
+ merged[len(merged)-1] = gjson.Parse(mergedMsg)
+ } else {
+ merged = append(merged, msg)
+ }
+ }
+
+ return merged
+}
+
+// mergeMessageContent merges the content of two messages with the same role.
+// Handles both string content and array content (with text, tool_use, tool_result blocks).
+func mergeMessageContent(msg1, msg2 gjson.Result) string {
+ content1 := msg1.Get("content")
+ content2 := msg2.Get("content")
+
+ // Extract content blocks from both messages
+ var blocks1, blocks2 []map[string]interface{}
+
+ if content1.IsArray() {
+ for _, block := range content1.Array() {
+ blocks1 = append(blocks1, blockToMap(block))
+ }
+ } else if content1.Type == gjson.String {
+ blocks1 = append(blocks1, map[string]interface{}{
+ "type": "text",
+ "text": content1.String(),
+ })
+ }
+
+ if content2.IsArray() {
+ for _, block := range content2.Array() {
+ blocks2 = append(blocks2, blockToMap(block))
+ }
+ } else if content2.Type == gjson.String {
+ blocks2 = append(blocks2, map[string]interface{}{
+ "type": "text",
+ "text": content2.String(),
+ })
+ }
+
+ // Merge text blocks if both end/start with text
+ if len(blocks1) > 0 && len(blocks2) > 0 {
+ if blocks1[len(blocks1)-1]["type"] == "text" && blocks2[0]["type"] == "text" {
+ // Merge the last text block of msg1 with the first text block of msg2
+ text1 := blocks1[len(blocks1)-1]["text"].(string)
+ text2 := blocks2[0]["text"].(string)
+ blocks1[len(blocks1)-1]["text"] = text1 + "\n" + text2
+ blocks2 = blocks2[1:] // Remove the merged block from blocks2
+ }
+ }
+
+ // Combine all blocks
+ allBlocks := append(blocks1, blocks2...)
+
+ // Convert to JSON
+ result, _ := json.Marshal(allBlocks)
+ return string(result)
+}
+
+// blockToMap converts a gjson.Result block to a map[string]interface{}
+func blockToMap(block gjson.Result) map[string]interface{} {
+ result := make(map[string]interface{})
+ block.ForEach(func(key, value gjson.Result) bool {
+ if value.IsObject() {
+ result[key.String()] = blockToMap(value)
+ } else if value.IsArray() {
+ var arr []interface{}
+ for _, item := range value.Array() {
+ if item.IsObject() {
+ arr = append(arr, blockToMap(item))
+ } else {
+ arr = append(arr, item.Value())
+ }
+ }
+ result[key.String()] = arr
+ } else {
+ result[key.String()] = value.Value()
+ }
+ return true
+ })
+ return result
+}
+
+// createMergedMessage creates a JSON string for a merged message.
+// toolCalls is optional and only emitted for assistant role.
+func createMergedMessage(role string, content string, toolCalls []interface{}) string {
+ msg := map[string]interface{}{
+ "role": role,
+ "content": json.RawMessage(content),
+ }
+ if role == "assistant" && len(toolCalls) > 0 {
+ msg["tool_calls"] = toolCalls
+ }
+ result, _ := json.Marshal(msg)
+ return string(result)
+}
+
+// mergeToolCalls combines tool_calls from two assistant messages while preserving order.
+func mergeToolCalls(tc1, tc2 gjson.Result) []interface{} {
+ var merged []interface{}
+
+ if tc1.IsArray() {
+ for _, tc := range tc1.Array() {
+ merged = append(merged, tc.Value())
+ }
+ }
+ if tc2.IsArray() {
+ for _, tc := range tc2.Array() {
+ merged = append(merged, tc.Value())
+ }
+ }
+
+ return merged
+}
diff --git a/internal/translator/kiro/common/message_merge_test.go b/internal/translator/kiro/common/message_merge_test.go
new file mode 100644
index 00000000..a9cb7a28
--- /dev/null
+++ b/internal/translator/kiro/common/message_merge_test.go
@@ -0,0 +1,106 @@
+package common
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/tidwall/gjson"
+)
+
+func parseMessages(t *testing.T, raw string) []gjson.Result {
+ t.Helper()
+ parsed := gjson.Parse(raw)
+ if !parsed.IsArray() {
+ t.Fatalf("expected JSON array, got: %s", raw)
+ }
+ return parsed.Array()
+}
+
+func TestMergeAdjacentMessages_AssistantMergePreservesToolCalls(t *testing.T) {
+ messages := parseMessages(t, `[
+ {"role":"assistant","content":"part1"},
+ {
+ "role":"assistant",
+ "content":"part2",
+ "tool_calls":[
+ {
+ "id":"call_1",
+ "type":"function",
+ "function":{"name":"Read","arguments":"{}"}
+ }
+ ]
+ },
+ {"role":"tool","tool_call_id":"call_1","content":"ok"}
+ ]`)
+
+ merged := MergeAdjacentMessages(messages)
+ if len(merged) != 2 {
+ t.Fatalf("expected 2 messages after merge, got %d", len(merged))
+ }
+
+ assistant := merged[0]
+ if assistant.Get("role").String() != "assistant" {
+ t.Fatalf("expected first message role assistant, got %q", assistant.Get("role").String())
+ }
+
+ toolCalls := assistant.Get("tool_calls")
+ if !toolCalls.IsArray() || len(toolCalls.Array()) != 1 {
+ t.Fatalf("expected assistant.tool_calls length 1, got: %s", toolCalls.Raw)
+ }
+ if toolCalls.Array()[0].Get("id").String() != "call_1" {
+ t.Fatalf("expected tool call id call_1, got %q", toolCalls.Array()[0].Get("id").String())
+ }
+
+ contentRaw := assistant.Get("content").Raw
+ if !strings.Contains(contentRaw, "part1") || !strings.Contains(contentRaw, "part2") {
+ t.Fatalf("expected merged content to contain both parts, got: %s", contentRaw)
+ }
+
+ if merged[1].Get("role").String() != "tool" {
+ t.Fatalf("expected second message role tool, got %q", merged[1].Get("role").String())
+ }
+}
+
+func TestMergeAdjacentMessages_AssistantMergeCombinesMultipleToolCalls(t *testing.T) {
+ messages := parseMessages(t, `[
+ {
+ "role":"assistant",
+ "content":"first",
+ "tool_calls":[
+ {"id":"call_1","type":"function","function":{"name":"Read","arguments":"{}"}}
+ ]
+ },
+ {
+ "role":"assistant",
+ "content":"second",
+ "tool_calls":[
+ {"id":"call_2","type":"function","function":{"name":"Write","arguments":"{}"}}
+ ]
+ }
+ ]`)
+
+ merged := MergeAdjacentMessages(messages)
+ if len(merged) != 1 {
+ t.Fatalf("expected 1 message after merge, got %d", len(merged))
+ }
+
+ toolCalls := merged[0].Get("tool_calls").Array()
+ if len(toolCalls) != 2 {
+ t.Fatalf("expected 2 merged tool calls, got %d", len(toolCalls))
+ }
+ if toolCalls[0].Get("id").String() != "call_1" || toolCalls[1].Get("id").String() != "call_2" {
+ t.Fatalf("unexpected merged tool call ids: %q, %q", toolCalls[0].Get("id").String(), toolCalls[1].Get("id").String())
+ }
+}
+
+func TestMergeAdjacentMessages_ToolMessagesRemainUnmerged(t *testing.T) {
+ messages := parseMessages(t, `[
+ {"role":"tool","tool_call_id":"call_1","content":"r1"},
+ {"role":"tool","tool_call_id":"call_2","content":"r2"}
+ ]`)
+
+ merged := MergeAdjacentMessages(messages)
+ if len(merged) != 2 {
+ t.Fatalf("expected tool messages to remain separate, got %d", len(merged))
+ }
+}
diff --git a/internal/translator/kiro/common/utils.go b/internal/translator/kiro/common/utils.go
new file mode 100644
index 00000000..f5f5788a
--- /dev/null
+++ b/internal/translator/kiro/common/utils.go
@@ -0,0 +1,16 @@
+// Package common provides shared constants and utilities for Kiro translator.
+package common
+
+// GetString safely extracts a string from a map.
+// Returns empty string if the key doesn't exist or the value is not a string.
+func GetString(m map[string]interface{}, key string) string {
+ if v, ok := m[key].(string); ok {
+ return v
+ }
+ return ""
+}
+
+// GetStringValue is an alias for GetString for backward compatibility.
+func GetStringValue(m map[string]interface{}, key string) string {
+ return GetString(m, key)
+}
\ No newline at end of file
diff --git a/internal/translator/kiro/openai/init.go b/internal/translator/kiro/openai/init.go
new file mode 100644
index 00000000..653eed45
--- /dev/null
+++ b/internal/translator/kiro/openai/init.go
@@ -0,0 +1,20 @@
+// Package openai provides translation between OpenAI Chat Completions and Kiro formats.
+package openai
+
+import (
+ . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
+)
+
+func init() {
+ translator.Register(
+ OpenAI, // source format
+ Kiro, // target format
+ ConvertOpenAIRequestToKiro,
+ interfaces.TranslateResponse{
+ Stream: ConvertKiroStreamToOpenAI,
+ NonStream: ConvertKiroNonStreamToOpenAI,
+ },
+ )
+}
\ No newline at end of file
diff --git a/internal/translator/kiro/openai/kiro_openai.go b/internal/translator/kiro/openai/kiro_openai.go
new file mode 100644
index 00000000..03962b9f
--- /dev/null
+++ b/internal/translator/kiro/openai/kiro_openai.go
@@ -0,0 +1,371 @@
+// Package openai provides translation between OpenAI Chat Completions and Kiro formats.
+// This package enables direct OpenAI → Kiro translation, bypassing the Claude intermediate layer.
+//
+// The Kiro executor generates Claude-compatible SSE format internally, so the streaming response
+// translation converts from Claude SSE format to OpenAI SSE format.
+package openai
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "strings"
+
+ kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
+ log "github.com/sirupsen/logrus"
+ "github.com/tidwall/gjson"
+)
+
+// ConvertKiroStreamToOpenAI converts Kiro streaming response to OpenAI format.
+// The Kiro executor emits Claude-compatible SSE events, so this function translates
+// from Claude SSE format to OpenAI SSE format.
+//
+// Claude SSE format:
+// - event: message_start\ndata: {...}
+// - event: content_block_start\ndata: {...}
+// - event: content_block_delta\ndata: {...}
+// - event: content_block_stop\ndata: {...}
+// - event: message_delta\ndata: {...}
+// - event: message_stop\ndata: {...}
+//
+// OpenAI SSE format:
+// - data: {"id":"...","object":"chat.completion.chunk",...}
+// - data: [DONE]
+func ConvertKiroStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
+ // Initialize state if needed
+ if *param == nil {
+ *param = NewOpenAIStreamState(model)
+ }
+ state := (*param).(*OpenAIStreamState)
+
+ // Parse the Claude SSE event
+ responseStr := string(rawResponse)
+
+ // Handle raw event format (event: xxx\ndata: {...})
+ var eventType string
+ var eventData string
+
+ if strings.HasPrefix(responseStr, "event:") {
+ // Parse event type and data
+ lines := strings.SplitN(responseStr, "\n", 2)
+ if len(lines) >= 1 {
+ eventType = strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
+ }
+ if len(lines) >= 2 && strings.HasPrefix(lines[1], "data:") {
+ eventData = strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
+ }
+ } else if strings.HasPrefix(responseStr, "data:") {
+ // Just data line
+ eventData = strings.TrimSpace(strings.TrimPrefix(responseStr, "data:"))
+ } else {
+ // Try to parse as raw JSON
+ eventData = strings.TrimSpace(responseStr)
+ }
+
+ if eventData == "" {
+ return []string{}
+ }
+
+ // Parse the event data as JSON
+ eventJSON := gjson.Parse(eventData)
+ if !eventJSON.Exists() {
+ return []string{}
+ }
+
+ // Determine event type from JSON if not already set
+ if eventType == "" {
+ eventType = eventJSON.Get("type").String()
+ }
+
+ var results []string
+
+ switch eventType {
+ case "message_start":
+ // Send first chunk with role
+ firstChunk := BuildOpenAISSEFirstChunk(state)
+ results = append(results, firstChunk)
+
+ case "content_block_start":
+ // Check block type
+ blockType := eventJSON.Get("content_block.type").String()
+ switch blockType {
+ case "text":
+ // Text block starting - nothing to emit yet
+ case "thinking":
+ // Thinking block starting - nothing to emit yet for OpenAI
+ case "tool_use":
+ // Tool use block starting
+ toolUseID := eventJSON.Get("content_block.id").String()
+ toolName := eventJSON.Get("content_block.name").String()
+ chunk := BuildOpenAISSEToolCallStart(state, toolUseID, toolName)
+ results = append(results, chunk)
+ state.ToolCallIndex++
+ }
+
+ case "content_block_delta":
+ deltaType := eventJSON.Get("delta.type").String()
+ switch deltaType {
+ case "text_delta":
+ textDelta := eventJSON.Get("delta.text").String()
+ if textDelta != "" {
+ chunk := BuildOpenAISSETextDelta(state, textDelta)
+ results = append(results, chunk)
+ }
+ case "thinking_delta":
+ // Convert thinking to reasoning_content for o1-style compatibility
+ thinkingDelta := eventJSON.Get("delta.thinking").String()
+ if thinkingDelta != "" {
+ chunk := BuildOpenAISSEReasoningDelta(state, thinkingDelta)
+ results = append(results, chunk)
+ }
+ case "input_json_delta":
+ // Tool call arguments delta
+ partialJSON := eventJSON.Get("delta.partial_json").String()
+ if partialJSON != "" {
+ // Get the tool index from content block index
+ blockIndex := int(eventJSON.Get("index").Int())
+ chunk := BuildOpenAISSEToolCallArgumentsDelta(state, partialJSON, blockIndex-1) // Adjust for 0-based tool index
+ results = append(results, chunk)
+ }
+ }
+
+ case "content_block_stop":
+ // Content block ended - nothing to emit for OpenAI
+
+ case "message_delta":
+ // Message delta with stop_reason
+ stopReason := eventJSON.Get("delta.stop_reason").String()
+ finishReason := mapKiroStopReasonToOpenAI(stopReason)
+ if finishReason != "" {
+ chunk := BuildOpenAISSEFinish(state, finishReason)
+ results = append(results, chunk)
+ }
+
+ // Extract usage if present
+ if eventJSON.Get("usage").Exists() {
+ inputTokens := eventJSON.Get("usage.input_tokens").Int()
+ outputTokens := eventJSON.Get("usage.output_tokens").Int()
+ usageInfo := usage.Detail{
+ InputTokens: inputTokens,
+ OutputTokens: outputTokens,
+ TotalTokens: inputTokens + outputTokens,
+ }
+ chunk := BuildOpenAISSEUsage(state, usageInfo)
+ results = append(results, chunk)
+ }
+
+ case "message_stop":
+ // Final event - do NOT emit [DONE] here
+ // The handler layer (openai_handlers.go) will send [DONE] when the stream closes
+ // Emitting [DONE] here would cause duplicate [DONE] markers
+
+ case "ping":
+ // Ping event with usage - optionally emit usage chunk
+ if eventJSON.Get("usage").Exists() {
+ inputTokens := eventJSON.Get("usage.input_tokens").Int()
+ outputTokens := eventJSON.Get("usage.output_tokens").Int()
+ usageInfo := usage.Detail{
+ InputTokens: inputTokens,
+ OutputTokens: outputTokens,
+ TotalTokens: inputTokens + outputTokens,
+ }
+ chunk := BuildOpenAISSEUsage(state, usageInfo)
+ results = append(results, chunk)
+ }
+ }
+
+ return results
+}
+
+// ConvertKiroNonStreamToOpenAI converts Kiro non-streaming response to OpenAI format.
+// The Kiro executor returns Claude-compatible JSON responses, so this function translates
+// from Claude format to OpenAI format.
+func ConvertKiroNonStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
+ // Parse the Claude-format response
+ response := gjson.ParseBytes(rawResponse)
+
+ // Extract content
+ var content string
+ var reasoningContent string
+ var toolUses []KiroToolUse
+ var stopReason string
+
+ // Get stop_reason
+ stopReason = response.Get("stop_reason").String()
+
+ // Process content blocks
+ contentBlocks := response.Get("content")
+ if contentBlocks.IsArray() {
+ for _, block := range contentBlocks.Array() {
+ blockType := block.Get("type").String()
+ switch blockType {
+ case "text":
+ content += block.Get("text").String()
+ case "thinking":
+ // Convert thinking blocks to reasoning_content for OpenAI format
+ reasoningContent += block.Get("thinking").String()
+ case "tool_use":
+ toolUseID := block.Get("id").String()
+ toolName := block.Get("name").String()
+ toolInput := block.Get("input")
+
+ var inputMap map[string]interface{}
+ if toolInput.IsObject() {
+ inputMap = make(map[string]interface{})
+ toolInput.ForEach(func(key, value gjson.Result) bool {
+ inputMap[key.String()] = value.Value()
+ return true
+ })
+ }
+
+ toolUses = append(toolUses, KiroToolUse{
+ ToolUseID: toolUseID,
+ Name: toolName,
+ Input: inputMap,
+ })
+ }
+ }
+ }
+
+ // Extract usage
+ usageInfo := usage.Detail{
+ InputTokens: response.Get("usage.input_tokens").Int(),
+ OutputTokens: response.Get("usage.output_tokens").Int(),
+ }
+ usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens
+
+ // Build OpenAI response with reasoning_content support
+ openaiResponse := BuildOpenAIResponseWithReasoning(content, reasoningContent, toolUses, model, usageInfo, stopReason)
+ return string(openaiResponse)
+}
+
+// ParseClaudeEvent parses a Claude SSE event and returns the event type and data
+func ParseClaudeEvent(rawEvent []byte) (eventType string, eventData []byte) {
+ lines := bytes.Split(rawEvent, []byte("\n"))
+ for _, line := range lines {
+ line = bytes.TrimSpace(line)
+ if bytes.HasPrefix(line, []byte("event:")) {
+ eventType = string(bytes.TrimSpace(bytes.TrimPrefix(line, []byte("event:"))))
+ } else if bytes.HasPrefix(line, []byte("data:")) {
+ eventData = bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
+ }
+ }
+ return eventType, eventData
+}
+
+// ExtractThinkingFromContent parses content to extract thinking blocks.
+// Returns cleaned content (without thinking tags) and whether thinking was found.
+func ExtractThinkingFromContent(content string) (string, string, bool) {
+ if !strings.Contains(content, kirocommon.ThinkingStartTag) {
+ return content, "", false
+ }
+
+ var cleanedContent strings.Builder
+ var thinkingContent strings.Builder
+ hasThinking := false
+ remaining := content
+
+ for len(remaining) > 0 {
+ startIdx := strings.Index(remaining, kirocommon.ThinkingStartTag)
+ if startIdx == -1 {
+ cleanedContent.WriteString(remaining)
+ break
+ }
+
+ // Add content before thinking tag
+ cleanedContent.WriteString(remaining[:startIdx])
+
+ // Move past opening tag
+ remaining = remaining[startIdx+len(kirocommon.ThinkingStartTag):]
+
+ // Find closing tag
+ endIdx := strings.Index(remaining, kirocommon.ThinkingEndTag)
+ if endIdx == -1 {
+ // No closing tag - treat rest as thinking
+ thinkingContent.WriteString(remaining)
+ hasThinking = true
+ break
+ }
+
+ // Extract thinking content
+ thinkingContent.WriteString(remaining[:endIdx])
+ hasThinking = true
+ remaining = remaining[endIdx+len(kirocommon.ThinkingEndTag):]
+ }
+
+ return strings.TrimSpace(cleanedContent.String()), strings.TrimSpace(thinkingContent.String()), hasThinking
+}
+
+// ConvertOpenAIToolsToKiroFormat is a helper that converts OpenAI tools format to Kiro format
+func ConvertOpenAIToolsToKiroFormat(tools []map[string]interface{}) []KiroToolWrapper {
+ var kiroTools []KiroToolWrapper
+
+ for _, tool := range tools {
+ toolType, _ := tool["type"].(string)
+ if toolType != "function" {
+ continue
+ }
+
+ fn, ok := tool["function"].(map[string]interface{})
+ if !ok {
+ continue
+ }
+
+ name := kirocommon.GetString(fn, "name")
+ description := kirocommon.GetString(fn, "description")
+ parameters := ensureKiroInputSchema(fn["parameters"])
+
+ if name == "" {
+ continue
+ }
+
+ if description == "" {
+ description = "Tool: " + name
+ }
+
+ kiroTools = append(kiroTools, KiroToolWrapper{
+ ToolSpecification: KiroToolSpecification{
+ Name: name,
+ Description: description,
+ InputSchema: KiroInputSchema{JSON: parameters},
+ },
+ })
+ }
+
+ return kiroTools
+}
+
+// OpenAIStreamParams holds parameters for OpenAI streaming conversion
+type OpenAIStreamParams struct {
+ State *OpenAIStreamState
+ ThinkingState *ThinkingTagState
+ ToolCallsEmitted map[string]bool
+}
+
+// NewOpenAIStreamParams creates new streaming parameters
+func NewOpenAIStreamParams(model string) *OpenAIStreamParams {
+ return &OpenAIStreamParams{
+ State: NewOpenAIStreamState(model),
+ ThinkingState: NewThinkingTagState(),
+ ToolCallsEmitted: make(map[string]bool),
+ }
+}
+
+// ConvertClaudeToolUseToOpenAI converts a Claude tool_use block to OpenAI tool_calls format
+func ConvertClaudeToolUseToOpenAI(toolUseID, toolName string, input map[string]interface{}) map[string]interface{} {
+ inputJSON, _ := json.Marshal(input)
+ return map[string]interface{}{
+ "id": toolUseID,
+ "type": "function",
+ "function": map[string]interface{}{
+ "name": toolName,
+ "arguments": string(inputJSON),
+ },
+ }
+}
+
+// LogStreamEvent logs a streaming event for debugging
+func LogStreamEvent(eventType, data string) {
+ log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data))
+}
diff --git a/internal/translator/kiro/openai/kiro_openai_request.go b/internal/translator/kiro/openai/kiro_openai_request.go
new file mode 100644
index 00000000..9515848f
--- /dev/null
+++ b/internal/translator/kiro/openai/kiro_openai_request.go
@@ -0,0 +1,927 @@
+// Package openai provides request translation from OpenAI Chat Completions to Kiro format.
+// It handles parsing and transforming OpenAI API requests into the Kiro/Amazon Q API format,
+// extracting model information, system instructions, message contents, and tool declarations.
+package openai
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+ "time"
+ "unicode/utf8"
+
+ "github.com/google/uuid"
+ kiroclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude"
+ kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
+ log "github.com/sirupsen/logrus"
+ "github.com/tidwall/gjson"
+)
+
+// Kiro API request structs - reuse from kiroclaude package structure
+
+// KiroPayload is the top-level request structure for Kiro API
+type KiroPayload struct {
+ ConversationState KiroConversationState `json:"conversationState"`
+ ProfileArn string `json:"profileArn,omitempty"`
+ InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"`
+}
+
+// KiroInferenceConfig contains inference parameters for the Kiro API.
+type KiroInferenceConfig struct {
+ MaxTokens int `json:"maxTokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"topP,omitempty"`
+}
+
+// KiroConversationState holds the conversation context
+type KiroConversationState struct {
+ ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL"
+ ConversationID string `json:"conversationId"`
+ CurrentMessage KiroCurrentMessage `json:"currentMessage"`
+ History []KiroHistoryMessage `json:"history,omitempty"`
+}
+
+// KiroCurrentMessage wraps the current user message
+type KiroCurrentMessage struct {
+ UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
+}
+
+// KiroHistoryMessage represents a message in the conversation history
+type KiroHistoryMessage struct {
+ UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
+ AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
+}
+
+// KiroImage represents an image in Kiro API format
+type KiroImage struct {
+ Format string `json:"format"`
+ Source KiroImageSource `json:"source"`
+}
+
+// KiroImageSource contains the image data
+type KiroImageSource struct {
+ Bytes string `json:"bytes"` // base64 encoded image data
+}
+
+// KiroUserInputMessage represents a user message
+type KiroUserInputMessage struct {
+ Content string `json:"content"`
+ ModelID string `json:"modelId"`
+ Origin string `json:"origin"`
+ Images []KiroImage `json:"images,omitempty"`
+ UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
+}
+
+// KiroUserInputMessageContext contains tool-related context
+type KiroUserInputMessageContext struct {
+ ToolResults []KiroToolResult `json:"toolResults,omitempty"`
+ Tools []KiroToolWrapper `json:"tools,omitempty"`
+}
+
+// KiroToolResult represents a tool execution result
+type KiroToolResult struct {
+ Content []KiroTextContent `json:"content"`
+ Status string `json:"status"`
+ ToolUseID string `json:"toolUseId"`
+}
+
+// KiroTextContent represents text content
+type KiroTextContent struct {
+ Text string `json:"text"`
+}
+
+// KiroToolWrapper wraps a tool specification
+type KiroToolWrapper struct {
+ ToolSpecification KiroToolSpecification `json:"toolSpecification"`
+}
+
+// KiroToolSpecification defines a tool's schema
+type KiroToolSpecification struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ InputSchema KiroInputSchema `json:"inputSchema"`
+}
+
+// KiroInputSchema wraps the JSON schema for tool input
+type KiroInputSchema struct {
+ JSON interface{} `json:"json"`
+}
+
+// KiroAssistantResponseMessage represents an assistant message
+type KiroAssistantResponseMessage struct {
+ Content string `json:"content"`
+ ToolUses []KiroToolUse `json:"toolUses,omitempty"`
+}
+
+// KiroToolUse represents a tool invocation by the assistant
+type KiroToolUse struct {
+ ToolUseID string `json:"toolUseId"`
+ Name string `json:"name"`
+ Input map[string]interface{} `json:"input"`
+}
+
+// ConvertOpenAIRequestToKiro converts an OpenAI Chat Completions request to Kiro format.
+// This is the main entry point for request translation.
+// Note: The actual payload building happens in the executor, this just passes through
+// the OpenAI format which will be converted by BuildKiroPayloadFromOpenAI.
+func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
+ // Pass through the OpenAI format - actual conversion happens in BuildKiroPayloadFromOpenAI
+ return inputRawJSON
+}
+
+// BuildKiroPayloadFromOpenAI constructs the Kiro API request payload from OpenAI format.
+// Supports tool calling - tools are passed via userInputMessageContext.
+// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
+// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
+// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
+// headers parameter allows checking Anthropic-Beta header for thinking mode detection.
+// metadata parameter is kept for API compatibility but no longer used for thinking configuration.
+// Returns the payload and a boolean indicating whether thinking mode was injected.
+func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, headers http.Header, metadata map[string]any) ([]byte, bool) {
+ // Extract max_tokens for potential use in inferenceConfig
+ // Handle -1 as "use maximum" (Kiro max output is ~32000 tokens)
+ const kiroMaxOutputTokens = 32000
+ var maxTokens int64
+ if mt := gjson.GetBytes(openaiBody, "max_tokens"); mt.Exists() {
+ maxTokens = mt.Int()
+ if maxTokens == -1 {
+ maxTokens = kiroMaxOutputTokens
+ log.Debugf("kiro-openai: max_tokens=-1 converted to %d", kiroMaxOutputTokens)
+ }
+ }
+
+ // Extract temperature if specified
+ var temperature float64
+ var hasTemperature bool
+ if temp := gjson.GetBytes(openaiBody, "temperature"); temp.Exists() {
+ temperature = temp.Float()
+ hasTemperature = true
+ }
+
+ // Extract top_p if specified
+ var topP float64
+ var hasTopP bool
+ if tp := gjson.GetBytes(openaiBody, "top_p"); tp.Exists() {
+ topP = tp.Float()
+ hasTopP = true
+ log.Debugf("kiro-openai: extracted top_p: %.2f", topP)
+ }
+
+ // Normalize origin value for Kiro API compatibility
+ origin = normalizeOrigin(origin)
+ log.Debugf("kiro-openai: normalized origin value: %s", origin)
+
+ messages := gjson.GetBytes(openaiBody, "messages")
+
+ // For chat-only mode, don't include tools
+ var tools gjson.Result
+ if !isChatOnly {
+ tools = gjson.GetBytes(openaiBody, "tools")
+ }
+
+ // Extract system prompt from messages
+ systemPrompt := extractSystemPromptFromOpenAI(messages)
+
+ // Inject timestamp context
+ timestamp := time.Now().Format("2006-01-02 15:04:05 MST")
+ timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp)
+ if systemPrompt != "" {
+ systemPrompt = timestampContext + "\n\n" + systemPrompt
+ } else {
+ systemPrompt = timestampContext
+ }
+ log.Debugf("kiro-openai: injected timestamp context: %s", timestamp)
+
+ // Inject agentic optimization prompt for -agentic model variants
+ if isAgentic {
+ if systemPrompt != "" {
+ systemPrompt += "\n"
+ }
+ systemPrompt += kirocommon.KiroAgenticSystemPrompt
+ }
+
+ // Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints
+ // OpenAI tool_choice values: "none", "auto", "required", or {"type":"function","function":{"name":"..."}}
+ toolChoiceHint := extractToolChoiceHint(openaiBody)
+ if toolChoiceHint != "" {
+ if systemPrompt != "" {
+ systemPrompt += "\n"
+ }
+ systemPrompt += toolChoiceHint
+ log.Debugf("kiro-openai: injected tool_choice hint into system prompt")
+ }
+
+ // Handle response_format parameter - Kiro doesn't support it natively, so we inject system prompt hints
+ // OpenAI response_format: {"type": "json_object"} or {"type": "json_schema", "json_schema": {...}}
+ responseFormatHint := extractResponseFormatHint(openaiBody)
+ if responseFormatHint != "" {
+ if systemPrompt != "" {
+ systemPrompt += "\n"
+ }
+ systemPrompt += responseFormatHint
+ log.Debugf("kiro-openai: injected response_format hint into system prompt")
+ }
+
+ // Check for thinking mode
+ // Supports OpenAI reasoning_effort parameter, model name hints, and Anthropic-Beta header
+ thinkingEnabled := checkThinkingModeFromOpenAIWithHeaders(openaiBody, headers)
+
+ // Convert OpenAI tools to Kiro format
+ kiroTools := convertOpenAIToolsToKiro(tools)
+
+ // Thinking mode implementation:
+ // Kiro API supports official thinking/reasoning mode via tag.
+ // When set to "enabled", Kiro returns reasoning content as official reasoningContentEvent
+ // rather than inline tags in assistantResponseEvent.
+ // We use a high max_thinking_length to allow extensive reasoning.
+ if thinkingEnabled {
+ thinkingHint := `enabled
+200000`
+ if systemPrompt != "" {
+ systemPrompt = thinkingHint + "\n\n" + systemPrompt
+ } else {
+ systemPrompt = thinkingHint
+ }
+ log.Debugf("kiro-openai: injected thinking prompt (official mode)")
+ }
+
+ // Process messages and build history
+ history, currentUserMsg, currentToolResults := processOpenAIMessages(messages, modelID, origin)
+
+ // Build content with system prompt
+ if currentUserMsg != nil {
+ currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
+
+ // Deduplicate currentToolResults
+ currentToolResults = deduplicateToolResults(currentToolResults)
+
+ // Build userInputMessageContext with tools and tool results
+ if len(kiroTools) > 0 || len(currentToolResults) > 0 {
+ currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{
+ Tools: kiroTools,
+ ToolResults: currentToolResults,
+ }
+ }
+ }
+
+ // Build payload
+ var currentMessage KiroCurrentMessage
+ if currentUserMsg != nil {
+ currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg}
+ } else {
+ fallbackContent := ""
+ if systemPrompt != "" {
+ fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n"
+ }
+ currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{
+ Content: fallbackContent,
+ ModelID: modelID,
+ Origin: origin,
+ }}
+ }
+
+ // Build inferenceConfig if we have any inference parameters
+ // Note: Kiro API doesn't actually use max_tokens for thinking budget
+ var inferenceConfig *KiroInferenceConfig
+ if maxTokens > 0 || hasTemperature || hasTopP {
+ inferenceConfig = &KiroInferenceConfig{}
+ if maxTokens > 0 {
+ inferenceConfig.MaxTokens = int(maxTokens)
+ }
+ if hasTemperature {
+ inferenceConfig.Temperature = temperature
+ }
+ if hasTopP {
+ inferenceConfig.TopP = topP
+ }
+ }
+
+ payload := KiroPayload{
+ ConversationState: KiroConversationState{
+ ChatTriggerType: "MANUAL",
+ ConversationID: uuid.New().String(),
+ CurrentMessage: currentMessage,
+ History: history,
+ },
+ ProfileArn: profileArn,
+ InferenceConfig: inferenceConfig,
+ }
+
+ result, err := json.Marshal(payload)
+ if err != nil {
+ log.Debugf("kiro-openai: failed to marshal payload: %v", err)
+ return nil, false
+ }
+
+ return result, thinkingEnabled
+}
+
+// normalizeOrigin normalizes origin value for Kiro API compatibility
+func normalizeOrigin(origin string) string {
+ switch origin {
+ case "KIRO_CLI":
+ return "CLI"
+ case "KIRO_AI_EDITOR":
+ return "AI_EDITOR"
+ case "AMAZON_Q":
+ return "CLI"
+ case "KIRO_IDE":
+ return "AI_EDITOR"
+ default:
+ return origin
+ }
+}
+
+// extractSystemPromptFromOpenAI extracts system prompt from OpenAI messages
+func extractSystemPromptFromOpenAI(messages gjson.Result) string {
+ if !messages.IsArray() {
+ return ""
+ }
+
+ var systemParts []string
+ for _, msg := range messages.Array() {
+ if msg.Get("role").String() == "system" {
+ content := msg.Get("content")
+ if content.Type == gjson.String {
+ systemParts = append(systemParts, content.String())
+ } else if content.IsArray() {
+ // Handle array content format
+ for _, part := range content.Array() {
+ if part.Get("type").String() == "text" {
+ systemParts = append(systemParts, part.Get("text").String())
+ }
+ }
+ }
+ }
+ }
+
+ return strings.Join(systemParts, "\n")
+}
+
+// shortenToolNameIfNeeded shortens tool names that exceed 64 characters.
+// MCP tools often have long names like "mcp__server-name__tool-name".
+// This preserves the "mcp__" prefix and last segment when possible.
+func shortenToolNameIfNeeded(name string) string {
+ const limit = 64
+ if len(name) <= limit {
+ return name
+ }
+ // For MCP tools, try to preserve prefix and last segment
+ if strings.HasPrefix(name, "mcp__") {
+ idx := strings.LastIndex(name, "__")
+ if idx > 0 {
+ cand := "mcp__" + name[idx+2:]
+ if len(cand) > limit {
+ return cand[:limit]
+ }
+ return cand
+ }
+ }
+ return name[:limit]
+}
+
+func ensureKiroInputSchema(parameters interface{}) interface{} {
+ if parameters != nil {
+ return parameters
+ }
+ return map[string]interface{}{
+ "type": "object",
+ "properties": map[string]interface{}{},
+ }
+}
+
+// convertOpenAIToolsToKiro converts OpenAI tools to Kiro format
+func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
+ var kiroTools []KiroToolWrapper
+ if !tools.IsArray() {
+ return kiroTools
+ }
+
+ for _, tool := range tools.Array() {
+ // OpenAI tools have type "function" with function definition inside
+ if tool.Get("type").String() != "function" {
+ continue
+ }
+
+ fn := tool.Get("function")
+ if !fn.Exists() {
+ continue
+ }
+
+ name := fn.Get("name").String()
+ description := fn.Get("description").String()
+ parametersResult := fn.Get("parameters")
+ var parameters interface{}
+ if parametersResult.Exists() && parametersResult.Type != gjson.Null {
+ parameters = parametersResult.Value()
+ }
+ parameters = ensureKiroInputSchema(parameters)
+
+ // Shorten tool name if it exceeds 64 characters (common with MCP tools)
+ originalName := name
+ name = shortenToolNameIfNeeded(name)
+ if name != originalName {
+ log.Debugf("kiro-openai: shortened tool name from '%s' to '%s'", originalName, name)
+ }
+
+ // CRITICAL FIX: Kiro API requires non-empty description
+ if strings.TrimSpace(description) == "" {
+ description = fmt.Sprintf("Tool: %s", name)
+ log.Debugf("kiro-openai: tool '%s' has empty description, using default: %s", name, description)
+ }
+
+ // Truncate long descriptions
+ if len(description) > kirocommon.KiroMaxToolDescLen {
+ truncLen := kirocommon.KiroMaxToolDescLen - 30
+ for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
+ truncLen--
+ }
+ description = description[:truncLen] + "... (description truncated)"
+ }
+
+ kiroTools = append(kiroTools, KiroToolWrapper{
+ ToolSpecification: KiroToolSpecification{
+ Name: name,
+ Description: description,
+ InputSchema: KiroInputSchema{JSON: parameters},
+ },
+ })
+ }
+
+ return kiroTools
+}
+
+// processOpenAIMessages processes OpenAI messages and builds Kiro history
+func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) {
+ var history []KiroHistoryMessage
+ var currentUserMsg *KiroUserInputMessage
+ var currentToolResults []KiroToolResult
+
+ if !messages.IsArray() {
+ return history, currentUserMsg, currentToolResults
+ }
+
+ // Merge adjacent messages with the same role
+ messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
+
+ // Track pending tool results that should be attached to the next user message
+ // This is critical for LiteLLM-translated requests where tool results appear
+ // as separate "tool" role messages between assistant and user messages
+ var pendingToolResults []KiroToolResult
+
+ for i, msg := range messagesArray {
+ role := msg.Get("role").String()
+ isLastMessage := i == len(messagesArray)-1
+
+ switch role {
+ case "system":
+ // System messages are handled separately via extractSystemPromptFromOpenAI
+ continue
+
+ case "user":
+ userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin)
+ // Merge any pending tool results from preceding "tool" role messages
+ toolResults = append(pendingToolResults, toolResults...)
+ pendingToolResults = nil // Reset pending tool results
+
+ if isLastMessage {
+ currentUserMsg = &userMsg
+ currentToolResults = toolResults
+ } else {
+ // CRITICAL: Kiro API requires content to be non-empty for history messages
+ if strings.TrimSpace(userMsg.Content) == "" {
+ if len(toolResults) > 0 {
+ userMsg.Content = "Tool results provided."
+ } else {
+ userMsg.Content = "Continue"
+ }
+ }
+ // For history messages, embed tool results in context
+ if len(toolResults) > 0 {
+ userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
+ ToolResults: toolResults,
+ }
+ }
+ history = append(history, KiroHistoryMessage{
+ UserInputMessage: &userMsg,
+ })
+ }
+
+ case "assistant":
+ assistantMsg := buildAssistantMessageFromOpenAI(msg)
+
+ // If there are pending tool results, we need to insert a synthetic user message
+ // before this assistant message to maintain proper conversation structure
+ if len(pendingToolResults) > 0 {
+ syntheticUserMsg := KiroUserInputMessage{
+ Content: "Tool results provided.",
+ ModelID: modelID,
+ Origin: origin,
+ UserInputMessageContext: &KiroUserInputMessageContext{
+ ToolResults: pendingToolResults,
+ },
+ }
+ history = append(history, KiroHistoryMessage{
+ UserInputMessage: &syntheticUserMsg,
+ })
+ pendingToolResults = nil
+ }
+
+ if isLastMessage {
+ history = append(history, KiroHistoryMessage{
+ AssistantResponseMessage: &assistantMsg,
+ })
+ // Create a "Continue" user message as currentMessage
+ currentUserMsg = &KiroUserInputMessage{
+ Content: "Continue",
+ ModelID: modelID,
+ Origin: origin,
+ }
+ } else {
+ history = append(history, KiroHistoryMessage{
+ AssistantResponseMessage: &assistantMsg,
+ })
+ }
+
+ case "tool":
+ // Tool messages in OpenAI format provide results for tool_calls
+ // These are typically followed by user or assistant messages
+ // Collect them as pending and attach to the next user message
+ toolCallID := msg.Get("tool_call_id").String()
+ content := msg.Get("content").String()
+
+ if toolCallID != "" {
+ toolResult := KiroToolResult{
+ ToolUseID: toolCallID,
+ Content: []KiroTextContent{{Text: content}},
+ Status: "success",
+ }
+ // Collect pending tool results to attach to the next user message
+ pendingToolResults = append(pendingToolResults, toolResult)
+ }
+ }
+ }
+
+ // Handle case where tool results are at the end with no following user message
+ if len(pendingToolResults) > 0 {
+ currentToolResults = append(currentToolResults, pendingToolResults...)
+ // If there's no current user message, create a synthetic one for the tool results
+ if currentUserMsg == nil {
+ currentUserMsg = &KiroUserInputMessage{
+ Content: "Tool results provided.",
+ ModelID: modelID,
+ Origin: origin,
+ }
+ }
+ }
+
+ // Truncate history if too long to prevent Kiro API errors
+ history = truncateHistoryIfNeeded(history)
+
+ return history, currentUserMsg, currentToolResults
+}
+
+const kiroMaxHistoryMessages = 50
+
+func truncateHistoryIfNeeded(history []KiroHistoryMessage) []KiroHistoryMessage {
+ if len(history) <= kiroMaxHistoryMessages {
+ return history
+ }
+
+ log.Debugf("kiro-openai: truncating history from %d to %d messages", len(history), kiroMaxHistoryMessages)
+ return history[len(history)-kiroMaxHistoryMessages:]
+}
+
+// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
+func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
+ content := msg.Get("content")
+ var contentBuilder strings.Builder
+ var toolResults []KiroToolResult
+ var images []KiroImage
+
+ if content.IsArray() {
+ for _, part := range content.Array() {
+ partType := part.Get("type").String()
+ switch partType {
+ case "text":
+ contentBuilder.WriteString(part.Get("text").String())
+ case "image_url":
+ imageURL := part.Get("image_url.url").String()
+ if strings.HasPrefix(imageURL, "data:") {
+ // Parse data URL: data:image/png;base64,xxxxx
+ if idx := strings.Index(imageURL, ";base64,"); idx != -1 {
+ mediaType := imageURL[5:idx] // Skip "data:"
+ data := imageURL[idx+8:] // Skip ";base64,"
+
+ format := ""
+ if lastSlash := strings.LastIndex(mediaType, "/"); lastSlash != -1 {
+ format = mediaType[lastSlash+1:]
+ }
+
+ if format != "" && data != "" {
+ images = append(images, KiroImage{
+ Format: format,
+ Source: KiroImageSource{
+ Bytes: data,
+ },
+ })
+ }
+ }
+ }
+ }
+ }
+ } else if content.Type == gjson.String {
+ contentBuilder.WriteString(content.String())
+ }
+
+ userMsg := KiroUserInputMessage{
+ Content: contentBuilder.String(),
+ ModelID: modelID,
+ Origin: origin,
+ }
+
+ if len(images) > 0 {
+ userMsg.Images = images
+ }
+
+ return userMsg, toolResults
+}
+
+// buildAssistantMessageFromOpenAI builds an assistant message from OpenAI format
+func buildAssistantMessageFromOpenAI(msg gjson.Result) KiroAssistantResponseMessage {
+ content := msg.Get("content")
+ var contentBuilder strings.Builder
+ var toolUses []KiroToolUse
+
+ // Handle content
+ if content.Type == gjson.String {
+ contentBuilder.WriteString(content.String())
+ } else if content.IsArray() {
+ for _, part := range content.Array() {
+ partType := part.Get("type").String()
+ switch partType {
+ case "text":
+ contentBuilder.WriteString(part.Get("text").String())
+ case "tool_use":
+ // Handle tool_use in content array (Anthropic/OpenCode format)
+ // This is different from OpenAI's tool_calls format
+ toolUseID := part.Get("id").String()
+ toolName := part.Get("name").String()
+ inputData := part.Get("input")
+
+ inputMap := make(map[string]interface{})
+ if inputData.Exists() && inputData.IsObject() {
+ inputData.ForEach(func(key, value gjson.Result) bool {
+ inputMap[key.String()] = value.Value()
+ return true
+ })
+ }
+
+ toolUses = append(toolUses, KiroToolUse{
+ ToolUseID: toolUseID,
+ Name: toolName,
+ Input: inputMap,
+ })
+ log.Debugf("kiro-openai: extracted tool_use from content array: %s", toolName)
+ }
+ }
+ }
+
+ // Handle tool_calls (OpenAI format)
+ toolCalls := msg.Get("tool_calls")
+ if toolCalls.IsArray() {
+ for _, tc := range toolCalls.Array() {
+ if tc.Get("type").String() != "function" {
+ continue
+ }
+
+ toolUseID := tc.Get("id").String()
+ toolName := tc.Get("function.name").String()
+ toolArgs := tc.Get("function.arguments").String()
+
+ var inputMap map[string]interface{}
+ if err := json.Unmarshal([]byte(toolArgs), &inputMap); err != nil {
+ log.Debugf("kiro-openai: failed to parse tool arguments: %v", err)
+ inputMap = make(map[string]interface{})
+ }
+
+ toolUses = append(toolUses, KiroToolUse{
+ ToolUseID: toolUseID,
+ Name: toolName,
+ Input: inputMap,
+ })
+ }
+ }
+
+ // CRITICAL FIX: Kiro API requires non-empty content for assistant messages
+ // This can happen with compaction requests or error recovery scenarios
+ finalContent := contentBuilder.String()
+ if strings.TrimSpace(finalContent) == "" {
+ if len(toolUses) > 0 {
+ finalContent = kirocommon.DefaultAssistantContentWithTools
+ } else {
+ finalContent = kirocommon.DefaultAssistantContent
+ }
+ log.Debugf("kiro-openai: assistant content was empty, using default: %s", finalContent)
+ }
+
+ return KiroAssistantResponseMessage{
+ Content: finalContent,
+ ToolUses: toolUses,
+ }
+}
+
+// buildFinalContent builds the final content with system prompt
+func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string {
+ var contentBuilder strings.Builder
+
+ if systemPrompt != "" {
+ contentBuilder.WriteString("--- SYSTEM PROMPT ---\n")
+ contentBuilder.WriteString(systemPrompt)
+ contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n")
+ }
+
+ contentBuilder.WriteString(content)
+ finalContent := contentBuilder.String()
+
+ // CRITICAL: Kiro API requires content to be non-empty
+ if strings.TrimSpace(finalContent) == "" {
+ if len(toolResults) > 0 {
+ finalContent = "Tool results provided."
+ } else {
+ finalContent = "Continue"
+ }
+ log.Debugf("kiro-openai: content was empty, using default: %s", finalContent)
+ }
+
+ return finalContent
+}
+
+// checkThinkingModeFromOpenAI checks if thinking mode is enabled in the OpenAI request.
+// Returns thinkingEnabled.
+// Supports:
+// - reasoning_effort parameter (low/medium/high/auto)
+// - Model name containing "thinking" or "reason"
+// - tag in system prompt (AMP/Cursor format)
+func checkThinkingModeFromOpenAI(openaiBody []byte) bool {
+ return checkThinkingModeFromOpenAIWithHeaders(openaiBody, nil)
+}
+
+// checkThinkingModeFromOpenAIWithHeaders checks if thinking mode is enabled in the OpenAI request.
+// Returns thinkingEnabled.
+// Supports:
+// - Anthropic-Beta header with interleaved-thinking (Claude CLI)
+// - reasoning_effort parameter (low/medium/high/auto)
+// - Model name containing "thinking" or "reason"
+// - tag in system prompt (AMP/Cursor format)
+func checkThinkingModeFromOpenAIWithHeaders(openaiBody []byte, headers http.Header) bool {
+ // Check Anthropic-Beta header first (Claude CLI uses this)
+ if kiroclaude.IsThinkingEnabledFromHeader(headers) {
+ log.Debugf("kiro-openai: thinking mode enabled via Anthropic-Beta header")
+ return true
+ }
+
+ // Check OpenAI format: reasoning_effort parameter
+ // Valid values: "low", "medium", "high", "auto" (not "none")
+ reasoningEffort := gjson.GetBytes(openaiBody, "reasoning_effort")
+ if reasoningEffort.Exists() {
+ effort := reasoningEffort.String()
+ if effort != "" && effort != "none" {
+ log.Debugf("kiro-openai: thinking mode enabled via reasoning_effort: %s", effort)
+ return true
+ }
+ }
+
+ // Check AMP/Cursor format: interleaved in system prompt
+ bodyStr := string(openaiBody)
+ if strings.Contains(bodyStr, "") && strings.Contains(bodyStr, "") {
+ startTag := ""
+ endTag := ""
+ startIdx := strings.Index(bodyStr, startTag)
+ if startIdx >= 0 {
+ startIdx += len(startTag)
+ endIdx := strings.Index(bodyStr[startIdx:], endTag)
+ if endIdx >= 0 {
+ thinkingMode := bodyStr[startIdx : startIdx+endIdx]
+ if thinkingMode == "interleaved" || thinkingMode == "enabled" {
+ log.Debugf("kiro-openai: thinking mode enabled via AMP/Cursor format: %s", thinkingMode)
+ return true
+ }
+ }
+ }
+ }
+
+ // Check model name for thinking hints
+ model := gjson.GetBytes(openaiBody, "model").String()
+ modelLower := strings.ToLower(model)
+ if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") {
+ log.Debugf("kiro-openai: thinking mode enabled via model name hint: %s", model)
+ return true
+ }
+
+ log.Debugf("kiro-openai: no thinking mode detected in OpenAI request")
+ return false
+}
+
+// hasThinkingTagInBody checks if the request body already contains thinking configuration tags.
+// This is used to prevent duplicate injection when client (e.g., AMP/Cursor) already includes thinking config.
+func hasThinkingTagInBody(body []byte) bool {
+ bodyStr := string(body)
+ return strings.Contains(bodyStr, "") || strings.Contains(bodyStr, "")
+}
+
+
+// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint.
+// OpenAI tool_choice values:
+// - "none": Don't use any tools
+// - "auto": Model decides (default, no hint needed)
+// - "required": Must use at least one tool
+// - {"type":"function","function":{"name":"..."}} : Must use specific tool
+func extractToolChoiceHint(openaiBody []byte) string {
+ toolChoice := gjson.GetBytes(openaiBody, "tool_choice")
+ if !toolChoice.Exists() {
+ return ""
+ }
+
+ // Handle string values
+ if toolChoice.Type == gjson.String {
+ switch toolChoice.String() {
+ case "none":
+ // Note: When tool_choice is "none", we should ideally not pass tools at all
+ // But since we can't modify tool passing here, we add a strong hint
+ return "[INSTRUCTION: Do NOT use any tools. Respond with text only.]"
+ case "required":
+ return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]"
+ case "auto":
+ // Default behavior, no hint needed
+ return ""
+ }
+ }
+
+ // Handle object value: {"type":"function","function":{"name":"..."}}
+ if toolChoice.IsObject() {
+ if toolChoice.Get("type").String() == "function" {
+ toolName := toolChoice.Get("function.name").String()
+ if toolName != "" {
+ return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName)
+ }
+ }
+ }
+
+ return ""
+}
+
+// extractResponseFormatHint extracts response_format from OpenAI request and returns a system prompt hint.
+// OpenAI response_format values:
+// - {"type": "text"}: Default, no hint needed
+// - {"type": "json_object"}: Must respond with valid JSON
+// - {"type": "json_schema", "json_schema": {...}}: Must respond with JSON matching schema
+func extractResponseFormatHint(openaiBody []byte) string {
+ responseFormat := gjson.GetBytes(openaiBody, "response_format")
+ if !responseFormat.Exists() {
+ return ""
+ }
+
+ formatType := responseFormat.Get("type").String()
+ switch formatType {
+ case "json_object":
+ return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]"
+ case "json_schema":
+ // Extract schema if provided
+ schema := responseFormat.Get("json_schema.schema")
+ if schema.Exists() {
+ schemaStr := schema.Raw
+ // Truncate if too long
+ if len(schemaStr) > 500 {
+ schemaStr = schemaStr[:500] + "..."
+ }
+ return fmt.Sprintf("[INSTRUCTION: You MUST respond with valid JSON that matches this schema: %s. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]", schemaStr)
+ }
+ return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]"
+ case "text":
+ // Default behavior, no hint needed
+ return ""
+ }
+
+ return ""
+}
+
+// deduplicateToolResults removes duplicate tool results
+func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult {
+ if len(toolResults) == 0 {
+ return toolResults
+ }
+
+ seenIDs := make(map[string]bool)
+ unique := make([]KiroToolResult, 0, len(toolResults))
+ for _, tr := range toolResults {
+ if !seenIDs[tr.ToolUseID] {
+ seenIDs[tr.ToolUseID] = true
+ unique = append(unique, tr)
+ } else {
+ log.Debugf("kiro-openai: skipping duplicate toolResult: %s", tr.ToolUseID)
+ }
+ }
+ return unique
+}
diff --git a/internal/translator/kiro/openai/kiro_openai_request_test.go b/internal/translator/kiro/openai/kiro_openai_request_test.go
new file mode 100644
index 00000000..85e95d4a
--- /dev/null
+++ b/internal/translator/kiro/openai/kiro_openai_request_test.go
@@ -0,0 +1,386 @@
+package openai
+
+import (
+ "encoding/json"
+ "testing"
+)
+
+// TestToolResultsAttachedToCurrentMessage verifies that tool results from "tool" role messages
+// are properly attached to the current user message (the last message in the conversation).
+// This is critical for LiteLLM-translated requests where tool results appear as separate messages.
+func TestToolResultsAttachedToCurrentMessage(t *testing.T) {
+ // OpenAI format request simulating LiteLLM's translation from Anthropic format
+ // Sequence: user -> assistant (with tool_calls) -> tool (result) -> user
+ // The last user message should have the tool results attached
+ input := []byte(`{
+ "model": "kiro-claude-opus-4-5-agentic",
+ "messages": [
+ {"role": "user", "content": "Hello, can you read a file for me?"},
+ {
+ "role": "assistant",
+ "content": "I'll read that file for you.",
+ "tool_calls": [
+ {
+ "id": "call_abc123",
+ "type": "function",
+ "function": {
+ "name": "Read",
+ "arguments": "{\"file_path\": \"/tmp/test.txt\"}"
+ }
+ }
+ ]
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_abc123",
+ "content": "File contents: Hello World!"
+ },
+ {"role": "user", "content": "What did the file say?"}
+ ]
+ }`)
+
+ result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
+
+ var payload KiroPayload
+ if err := json.Unmarshal(result, &payload); err != nil {
+ t.Fatalf("Failed to unmarshal result: %v", err)
+ }
+
+ // The last user message becomes currentMessage
+ // History should have: user (first), assistant (with tool_calls)
+ t.Logf("History count: %d", len(payload.ConversationState.History))
+ if len(payload.ConversationState.History) != 2 {
+ t.Errorf("Expected 2 history entries (user + assistant), got %d", len(payload.ConversationState.History))
+ }
+
+ // Tool results should be attached to currentMessage (the last user message)
+ ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
+ if ctx == nil {
+ t.Fatal("Expected currentMessage to have UserInputMessageContext with tool results")
+ }
+
+ if len(ctx.ToolResults) != 1 {
+ t.Fatalf("Expected 1 tool result in currentMessage, got %d", len(ctx.ToolResults))
+ }
+
+ tr := ctx.ToolResults[0]
+ if tr.ToolUseID != "call_abc123" {
+ t.Errorf("Expected toolUseId 'call_abc123', got '%s'", tr.ToolUseID)
+ }
+ if len(tr.Content) == 0 || tr.Content[0].Text != "File contents: Hello World!" {
+ t.Errorf("Tool result content mismatch, got: %+v", tr.Content)
+ }
+}
+
+// TestToolResultsInHistoryUserMessage verifies that when there are multiple user messages
+// after tool results, the tool results are attached to the correct user message in history.
+func TestToolResultsInHistoryUserMessage(t *testing.T) {
+ // Sequence: user -> assistant (with tool_calls) -> tool (result) -> user -> assistant -> user
+ // The first user after tool should have tool results in history
+ input := []byte(`{
+ "model": "kiro-claude-opus-4-5-agentic",
+ "messages": [
+ {"role": "user", "content": "Hello"},
+ {
+ "role": "assistant",
+ "content": "I'll read the file.",
+ "tool_calls": [
+ {
+ "id": "call_1",
+ "type": "function",
+ "function": {
+ "name": "Read",
+ "arguments": "{}"
+ }
+ }
+ ]
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_1",
+ "content": "File result"
+ },
+ {"role": "user", "content": "Thanks for the file"},
+ {"role": "assistant", "content": "You're welcome"},
+ {"role": "user", "content": "Bye"}
+ ]
+ }`)
+
+ result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
+
+ var payload KiroPayload
+ if err := json.Unmarshal(result, &payload); err != nil {
+ t.Fatalf("Failed to unmarshal result: %v", err)
+ }
+
+ // History should have: user, assistant, user (with tool results), assistant
+ // CurrentMessage should be: last user "Bye"
+ t.Logf("History count: %d", len(payload.ConversationState.History))
+
+ // Find the user message in history with tool results
+ foundToolResults := false
+ for i, h := range payload.ConversationState.History {
+ if h.UserInputMessage != nil {
+ t.Logf("History[%d]: user message content=%q", i, h.UserInputMessage.Content)
+ if h.UserInputMessage.UserInputMessageContext != nil {
+ if len(h.UserInputMessage.UserInputMessageContext.ToolResults) > 0 {
+ foundToolResults = true
+ t.Logf(" Found %d tool results", len(h.UserInputMessage.UserInputMessageContext.ToolResults))
+ tr := h.UserInputMessage.UserInputMessageContext.ToolResults[0]
+ if tr.ToolUseID != "call_1" {
+ t.Errorf("Expected toolUseId 'call_1', got '%s'", tr.ToolUseID)
+ }
+ }
+ }
+ }
+ if h.AssistantResponseMessage != nil {
+ t.Logf("History[%d]: assistant message content=%q", i, h.AssistantResponseMessage.Content)
+ }
+ }
+
+ if !foundToolResults {
+ t.Error("Tool results were not attached to any user message in history")
+ }
+}
+
+// TestToolResultsWithMultipleToolCalls verifies handling of multiple tool calls
+func TestToolResultsWithMultipleToolCalls(t *testing.T) {
+ input := []byte(`{
+ "model": "kiro-claude-opus-4-5-agentic",
+ "messages": [
+ {"role": "user", "content": "Read two files for me"},
+ {
+ "role": "assistant",
+ "content": "I'll read both files.",
+ "tool_calls": [
+ {
+ "id": "call_1",
+ "type": "function",
+ "function": {
+ "name": "Read",
+ "arguments": "{\"file_path\": \"/tmp/file1.txt\"}"
+ }
+ },
+ {
+ "id": "call_2",
+ "type": "function",
+ "function": {
+ "name": "Read",
+ "arguments": "{\"file_path\": \"/tmp/file2.txt\"}"
+ }
+ }
+ ]
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_1",
+ "content": "Content of file 1"
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_2",
+ "content": "Content of file 2"
+ },
+ {"role": "user", "content": "What do they say?"}
+ ]
+ }`)
+
+ result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
+
+ var payload KiroPayload
+ if err := json.Unmarshal(result, &payload); err != nil {
+ t.Fatalf("Failed to unmarshal result: %v", err)
+ }
+
+ t.Logf("History count: %d", len(payload.ConversationState.History))
+ t.Logf("CurrentMessage content: %q", payload.ConversationState.CurrentMessage.UserInputMessage.Content)
+
+ // Check if there are any tool results anywhere
+ var totalToolResults int
+ for i, h := range payload.ConversationState.History {
+ if h.UserInputMessage != nil && h.UserInputMessage.UserInputMessageContext != nil {
+ count := len(h.UserInputMessage.UserInputMessageContext.ToolResults)
+ t.Logf("History[%d] user message has %d tool results", i, count)
+ totalToolResults += count
+ }
+ }
+
+ ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
+ if ctx != nil {
+ t.Logf("CurrentMessage has %d tool results", len(ctx.ToolResults))
+ totalToolResults += len(ctx.ToolResults)
+ } else {
+ t.Logf("CurrentMessage has no UserInputMessageContext")
+ }
+
+ if totalToolResults != 2 {
+ t.Errorf("Expected 2 tool results total, got %d", totalToolResults)
+ }
+}
+
+// TestToolResultsAtEndOfConversation verifies tool results are handled when
+// the conversation ends with tool results (no following user message)
+func TestToolResultsAtEndOfConversation(t *testing.T) {
+ input := []byte(`{
+ "model": "kiro-claude-opus-4-5-agentic",
+ "messages": [
+ {"role": "user", "content": "Read a file"},
+ {
+ "role": "assistant",
+ "content": "Reading the file.",
+ "tool_calls": [
+ {
+ "id": "call_end",
+ "type": "function",
+ "function": {
+ "name": "Read",
+ "arguments": "{\"file_path\": \"/tmp/test.txt\"}"
+ }
+ }
+ ]
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_end",
+ "content": "File contents here"
+ }
+ ]
+ }`)
+
+ result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
+
+ var payload KiroPayload
+ if err := json.Unmarshal(result, &payload); err != nil {
+ t.Fatalf("Failed to unmarshal result: %v", err)
+ }
+
+ // When the last message is a tool result, a synthetic user message is created
+ // and tool results should be attached to it
+ ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
+ if ctx == nil || len(ctx.ToolResults) == 0 {
+ t.Error("Expected tool results to be attached to current message when conversation ends with tool result")
+ } else {
+ if ctx.ToolResults[0].ToolUseID != "call_end" {
+ t.Errorf("Expected toolUseId 'call_end', got '%s'", ctx.ToolResults[0].ToolUseID)
+ }
+ }
+}
+
+// TestToolResultsFollowedByAssistant verifies handling when tool results are followed
+// by an assistant message (no intermediate user message).
+// This is the pattern from LiteLLM translation of Anthropic format where:
+// user message has ONLY tool_result blocks -> LiteLLM creates tool messages
+// then the next message is assistant
+func TestToolResultsFollowedByAssistant(t *testing.T) {
+ // Sequence: user -> assistant (with tool_calls) -> tool -> tool -> assistant -> user
+ // This simulates LiteLLM's translation of:
+ // user: "Read files"
+ // assistant: [tool_use, tool_use]
+ // user: [tool_result, tool_result] <- becomes multiple "tool" role messages
+ // assistant: "I've read them"
+ // user: "What did they say?"
+ input := []byte(`{
+ "model": "kiro-claude-opus-4-5-agentic",
+ "messages": [
+ {"role": "user", "content": "Read two files for me"},
+ {
+ "role": "assistant",
+ "content": "I'll read both files.",
+ "tool_calls": [
+ {
+ "id": "call_1",
+ "type": "function",
+ "function": {
+ "name": "Read",
+ "arguments": "{\"file_path\": \"/tmp/a.txt\"}"
+ }
+ },
+ {
+ "id": "call_2",
+ "type": "function",
+ "function": {
+ "name": "Read",
+ "arguments": "{\"file_path\": \"/tmp/b.txt\"}"
+ }
+ }
+ ]
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_1",
+ "content": "Contents of file A"
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_2",
+ "content": "Contents of file B"
+ },
+ {
+ "role": "assistant",
+ "content": "I've read both files."
+ },
+ {"role": "user", "content": "What did they say?"}
+ ]
+ }`)
+
+ result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
+
+ var payload KiroPayload
+ if err := json.Unmarshal(result, &payload); err != nil {
+ t.Fatalf("Failed to unmarshal result: %v", err)
+ }
+
+ t.Logf("History count: %d", len(payload.ConversationState.History))
+
+ // Tool results should be attached to a synthetic user message or the history should be valid
+ var totalToolResults int
+ for i, h := range payload.ConversationState.History {
+ if h.UserInputMessage != nil {
+ t.Logf("History[%d]: user message content=%q", i, h.UserInputMessage.Content)
+ if h.UserInputMessage.UserInputMessageContext != nil {
+ count := len(h.UserInputMessage.UserInputMessageContext.ToolResults)
+ t.Logf(" Has %d tool results", count)
+ totalToolResults += count
+ }
+ }
+ if h.AssistantResponseMessage != nil {
+ t.Logf("History[%d]: assistant message content=%q", i, h.AssistantResponseMessage.Content)
+ }
+ }
+
+ ctx := payload.ConversationState.CurrentMessage.UserInputMessage.UserInputMessageContext
+ if ctx != nil {
+ t.Logf("CurrentMessage has %d tool results", len(ctx.ToolResults))
+ totalToolResults += len(ctx.ToolResults)
+ }
+
+ if totalToolResults != 2 {
+ t.Errorf("Expected 2 tool results total, got %d", totalToolResults)
+ }
+}
+
+// TestAssistantEndsConversation verifies handling when assistant is the last message
+func TestAssistantEndsConversation(t *testing.T) {
+ input := []byte(`{
+ "model": "kiro-claude-opus-4-5-agentic",
+ "messages": [
+ {"role": "user", "content": "Hello"},
+ {
+ "role": "assistant",
+ "content": "Hi there!"
+ }
+ ]
+ }`)
+
+ result, _ := BuildKiroPayloadFromOpenAI(input, "kiro-model", "", "CLI", false, false, nil, nil)
+
+ var payload KiroPayload
+ if err := json.Unmarshal(result, &payload); err != nil {
+ t.Fatalf("Failed to unmarshal result: %v", err)
+ }
+
+ // When assistant is last, a "Continue" user message should be created
+ if payload.ConversationState.CurrentMessage.UserInputMessage.Content == "" {
+ t.Error("Expected a 'Continue' message to be created when assistant is last")
+ }
+}
diff --git a/internal/translator/kiro/openai/kiro_openai_response.go b/internal/translator/kiro/openai/kiro_openai_response.go
new file mode 100644
index 00000000..edc70ad8
--- /dev/null
+++ b/internal/translator/kiro/openai/kiro_openai_response.go
@@ -0,0 +1,277 @@
+// Package openai provides response translation from Kiro to OpenAI format.
+// This package handles the conversion of Kiro API responses into OpenAI Chat Completions-compatible
+// JSON format, transforming streaming events and non-streaming responses.
+package openai
+
+import (
+ "encoding/json"
+ "fmt"
+ "sync/atomic"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
+ log "github.com/sirupsen/logrus"
+)
+
+// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
+var functionCallIDCounter uint64
+
+// BuildOpenAIResponse constructs an OpenAI Chat Completions-compatible response.
+// Supports tool_calls when tools are present in the response.
+// stopReason is passed from upstream; fallback logic applied if empty.
+func BuildOpenAIResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
+ return BuildOpenAIResponseWithReasoning(content, "", toolUses, model, usageInfo, stopReason)
+}
+
+// BuildOpenAIResponseWithReasoning constructs an OpenAI Chat Completions-compatible response with reasoning_content support.
+// Supports tool_calls when tools are present in the response.
+// reasoningContent is included as reasoning_content field in the message when present.
+// stopReason is passed from upstream; fallback logic applied if empty.
+func BuildOpenAIResponseWithReasoning(content, reasoningContent string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
+ // Build the message object
+ message := map[string]interface{}{
+ "role": "assistant",
+ "content": content,
+ }
+
+ // Add reasoning_content if present (for thinking/reasoning models)
+ if reasoningContent != "" {
+ message["reasoning_content"] = reasoningContent
+ }
+
+ // Add tool_calls if present
+ if len(toolUses) > 0 {
+ var toolCalls []map[string]interface{}
+ for i, tu := range toolUses {
+ inputJSON, _ := json.Marshal(tu.Input)
+ toolCalls = append(toolCalls, map[string]interface{}{
+ "id": tu.ToolUseID,
+ "type": "function",
+ "index": i,
+ "function": map[string]interface{}{
+ "name": tu.Name,
+ "arguments": string(inputJSON),
+ },
+ })
+ }
+ message["tool_calls"] = toolCalls
+ // When tool_calls are present, content should be null according to OpenAI spec
+ if content == "" {
+ message["content"] = nil
+ }
+ }
+
+ // Use upstream stopReason; apply fallback logic if not provided
+ finishReason := mapKiroStopReasonToOpenAI(stopReason)
+ if finishReason == "" {
+ finishReason = "stop"
+ if len(toolUses) > 0 {
+ finishReason = "tool_calls"
+ }
+ log.Debugf("kiro-openai: buildOpenAIResponse using fallback finish_reason: %s", finishReason)
+ }
+
+ response := map[string]interface{}{
+ "id": "chatcmpl-" + uuid.New().String()[:24],
+ "object": "chat.completion",
+ "created": time.Now().Unix(),
+ "model": model,
+ "choices": []map[string]interface{}{
+ {
+ "index": 0,
+ "message": message,
+ "finish_reason": finishReason,
+ },
+ },
+ "usage": map[string]interface{}{
+ "prompt_tokens": usageInfo.InputTokens,
+ "completion_tokens": usageInfo.OutputTokens,
+ "total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
+ },
+ }
+
+ result, _ := json.Marshal(response)
+ return result
+}
+
+// mapKiroStopReasonToOpenAI converts Kiro/Claude stop_reason to OpenAI finish_reason
+func mapKiroStopReasonToOpenAI(stopReason string) string {
+ switch stopReason {
+ case "end_turn":
+ return "stop"
+ case "stop_sequence":
+ return "stop"
+ case "tool_use":
+ return "tool_calls"
+ case "max_tokens":
+ return "length"
+ case "content_filtered":
+ return "content_filter"
+ default:
+ return stopReason
+ }
+}
+
+// BuildOpenAIStreamChunk constructs an OpenAI Chat Completions streaming chunk.
+// This is the delta format used in streaming responses.
+func BuildOpenAIStreamChunk(model string, deltaContent string, deltaToolCalls []map[string]interface{}, finishReason string, index int) []byte {
+ delta := map[string]interface{}{}
+
+ // First chunk should include role
+ if index == 0 && deltaContent == "" && len(deltaToolCalls) == 0 {
+ delta["role"] = "assistant"
+ delta["content"] = ""
+ } else if deltaContent != "" {
+ delta["content"] = deltaContent
+ }
+
+ // Add tool_calls delta if present
+ if len(deltaToolCalls) > 0 {
+ delta["tool_calls"] = deltaToolCalls
+ }
+
+ choice := map[string]interface{}{
+ "index": 0,
+ "delta": delta,
+ }
+
+ if finishReason != "" {
+ choice["finish_reason"] = finishReason
+ } else {
+ choice["finish_reason"] = nil
+ }
+
+ chunk := map[string]interface{}{
+ "id": "chatcmpl-" + uuid.New().String()[:12],
+ "object": "chat.completion.chunk",
+ "created": time.Now().Unix(),
+ "model": model,
+ "choices": []map[string]interface{}{choice},
+ }
+
+ result, _ := json.Marshal(chunk)
+ return result
+}
+
+// BuildOpenAIStreamChunkWithToolCallStart creates a stream chunk for tool call start
+func BuildOpenAIStreamChunkWithToolCallStart(model string, toolUseID, toolName string, toolIndex int) []byte {
+ toolCall := map[string]interface{}{
+ "index": toolIndex,
+ "id": toolUseID,
+ "type": "function",
+ "function": map[string]interface{}{
+ "name": toolName,
+ "arguments": "",
+ },
+ }
+
+ delta := map[string]interface{}{
+ "tool_calls": []map[string]interface{}{toolCall},
+ }
+
+ choice := map[string]interface{}{
+ "index": 0,
+ "delta": delta,
+ "finish_reason": nil,
+ }
+
+ chunk := map[string]interface{}{
+ "id": "chatcmpl-" + uuid.New().String()[:12],
+ "object": "chat.completion.chunk",
+ "created": time.Now().Unix(),
+ "model": model,
+ "choices": []map[string]interface{}{choice},
+ }
+
+ result, _ := json.Marshal(chunk)
+ return result
+}
+
+// BuildOpenAIStreamChunkWithToolCallDelta creates a stream chunk for tool call arguments delta
+func BuildOpenAIStreamChunkWithToolCallDelta(model string, argumentsDelta string, toolIndex int) []byte {
+ toolCall := map[string]interface{}{
+ "index": toolIndex,
+ "function": map[string]interface{}{
+ "arguments": argumentsDelta,
+ },
+ }
+
+ delta := map[string]interface{}{
+ "tool_calls": []map[string]interface{}{toolCall},
+ }
+
+ choice := map[string]interface{}{
+ "index": 0,
+ "delta": delta,
+ "finish_reason": nil,
+ }
+
+ chunk := map[string]interface{}{
+ "id": "chatcmpl-" + uuid.New().String()[:12],
+ "object": "chat.completion.chunk",
+ "created": time.Now().Unix(),
+ "model": model,
+ "choices": []map[string]interface{}{choice},
+ }
+
+ result, _ := json.Marshal(chunk)
+ return result
+}
+
+// BuildOpenAIStreamDoneChunk creates the final [DONE] stream event
+func BuildOpenAIStreamDoneChunk() []byte {
+ return []byte("data: [DONE]")
+}
+
+// BuildOpenAIStreamFinishChunk creates the final chunk with finish_reason
+func BuildOpenAIStreamFinishChunk(model string, finishReason string) []byte {
+ choice := map[string]interface{}{
+ "index": 0,
+ "delta": map[string]interface{}{},
+ "finish_reason": finishReason,
+ }
+
+ chunk := map[string]interface{}{
+ "id": "chatcmpl-" + uuid.New().String()[:12],
+ "object": "chat.completion.chunk",
+ "created": time.Now().Unix(),
+ "model": model,
+ "choices": []map[string]interface{}{choice},
+ }
+
+ result, _ := json.Marshal(chunk)
+ return result
+}
+
+// BuildOpenAIStreamUsageChunk creates a chunk with usage information (optional, for stream_options.include_usage)
+func BuildOpenAIStreamUsageChunk(model string, usageInfo usage.Detail) []byte {
+ chunk := map[string]interface{}{
+ "id": "chatcmpl-" + uuid.New().String()[:12],
+ "object": "chat.completion.chunk",
+ "created": time.Now().Unix(),
+ "model": model,
+ "choices": []map[string]interface{}{},
+ "usage": map[string]interface{}{
+ "prompt_tokens": usageInfo.InputTokens,
+ "completion_tokens": usageInfo.OutputTokens,
+ "total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
+ },
+ }
+
+ result, _ := json.Marshal(chunk)
+ return result
+}
+
+// GenerateToolCallID generates a unique tool call ID in OpenAI format
+func GenerateToolCallID(toolName string) string {
+ return fmt.Sprintf("call_%s_%d_%d", toolName[:min(8, len(toolName))], time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))
+}
+
+// min returns the minimum of two integers
+func min(a, b int) int {
+ if a < b {
+ return a
+ }
+ return b
+}
\ No newline at end of file
diff --git a/internal/translator/kiro/openai/kiro_openai_stream.go b/internal/translator/kiro/openai/kiro_openai_stream.go
new file mode 100644
index 00000000..e72d970e
--- /dev/null
+++ b/internal/translator/kiro/openai/kiro_openai_stream.go
@@ -0,0 +1,212 @@
+// Package openai provides streaming SSE event building for OpenAI format.
+// This package handles the construction of OpenAI-compatible Server-Sent Events (SSE)
+// for streaming responses from Kiro API.
+package openai
+
+import (
+ "encoding/json"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
+)
+
+// OpenAIStreamState tracks the state of streaming response conversion
+type OpenAIStreamState struct {
+ ChunkIndex int
+ ToolCallIndex int
+ HasSentFirstChunk bool
+ Model string
+ ResponseID string
+ Created int64
+}
+
+// NewOpenAIStreamState creates a new stream state for tracking
+func NewOpenAIStreamState(model string) *OpenAIStreamState {
+ return &OpenAIStreamState{
+ ChunkIndex: 0,
+ ToolCallIndex: 0,
+ HasSentFirstChunk: false,
+ Model: model,
+ ResponseID: "chatcmpl-" + uuid.New().String()[:24],
+ Created: time.Now().Unix(),
+ }
+}
+
+// FormatSSEEvent formats a JSON payload for SSE streaming.
+// Note: This returns raw JSON data without "data:" prefix.
+// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go)
+// to maintain architectural consistency and avoid double-prefix issues.
+func FormatSSEEvent(data []byte) string {
+ return string(data)
+}
+
+// BuildOpenAISSETextDelta creates an SSE event for text content delta
+func BuildOpenAISSETextDelta(state *OpenAIStreamState, textDelta string) string {
+ delta := map[string]interface{}{
+ "content": textDelta,
+ }
+
+ // Include role in first chunk
+ if !state.HasSentFirstChunk {
+ delta["role"] = "assistant"
+ state.HasSentFirstChunk = true
+ }
+
+ chunk := buildBaseChunk(state, delta, nil)
+ result, _ := json.Marshal(chunk)
+ state.ChunkIndex++
+ return FormatSSEEvent(result)
+}
+
+// BuildOpenAISSEToolCallStart creates an SSE event for tool call start
+func BuildOpenAISSEToolCallStart(state *OpenAIStreamState, toolUseID, toolName string) string {
+ toolCall := map[string]interface{}{
+ "index": state.ToolCallIndex,
+ "id": toolUseID,
+ "type": "function",
+ "function": map[string]interface{}{
+ "name": toolName,
+ "arguments": "",
+ },
+ }
+
+ delta := map[string]interface{}{
+ "tool_calls": []map[string]interface{}{toolCall},
+ }
+
+ // Include role in first chunk if not sent yet
+ if !state.HasSentFirstChunk {
+ delta["role"] = "assistant"
+ state.HasSentFirstChunk = true
+ }
+
+ chunk := buildBaseChunk(state, delta, nil)
+ result, _ := json.Marshal(chunk)
+ state.ChunkIndex++
+ return FormatSSEEvent(result)
+}
+
+// BuildOpenAISSEToolCallArgumentsDelta creates an SSE event for tool call arguments delta
+func BuildOpenAISSEToolCallArgumentsDelta(state *OpenAIStreamState, argumentsDelta string, toolIndex int) string {
+ toolCall := map[string]interface{}{
+ "index": toolIndex,
+ "function": map[string]interface{}{
+ "arguments": argumentsDelta,
+ },
+ }
+
+ delta := map[string]interface{}{
+ "tool_calls": []map[string]interface{}{toolCall},
+ }
+
+ chunk := buildBaseChunk(state, delta, nil)
+ result, _ := json.Marshal(chunk)
+ state.ChunkIndex++
+ return FormatSSEEvent(result)
+}
+
+// BuildOpenAISSEFinish creates an SSE event with finish_reason
+func BuildOpenAISSEFinish(state *OpenAIStreamState, finishReason string) string {
+ chunk := buildBaseChunk(state, map[string]interface{}{}, &finishReason)
+ result, _ := json.Marshal(chunk)
+ state.ChunkIndex++
+ return FormatSSEEvent(result)
+}
+
+// BuildOpenAISSEUsage creates an SSE event with usage information
+func BuildOpenAISSEUsage(state *OpenAIStreamState, usageInfo usage.Detail) string {
+ chunk := map[string]interface{}{
+ "id": state.ResponseID,
+ "object": "chat.completion.chunk",
+ "created": state.Created,
+ "model": state.Model,
+ "choices": []map[string]interface{}{},
+ "usage": map[string]interface{}{
+ "prompt_tokens": usageInfo.InputTokens,
+ "completion_tokens": usageInfo.OutputTokens,
+ "total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
+ },
+ }
+ result, _ := json.Marshal(chunk)
+ return FormatSSEEvent(result)
+}
+
+// BuildOpenAISSEDone creates the final [DONE] SSE event.
+// Note: This returns raw "[DONE]" without "data:" prefix.
+// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go)
+// to maintain architectural consistency and avoid double-prefix issues.
+func BuildOpenAISSEDone() string {
+ return "[DONE]"
+}
+
+// buildBaseChunk creates a base chunk structure for streaming
+func buildBaseChunk(state *OpenAIStreamState, delta map[string]interface{}, finishReason *string) map[string]interface{} {
+ choice := map[string]interface{}{
+ "index": 0,
+ "delta": delta,
+ }
+
+ if finishReason != nil {
+ choice["finish_reason"] = *finishReason
+ } else {
+ choice["finish_reason"] = nil
+ }
+
+ return map[string]interface{}{
+ "id": state.ResponseID,
+ "object": "chat.completion.chunk",
+ "created": state.Created,
+ "model": state.Model,
+ "choices": []map[string]interface{}{choice},
+ }
+}
+
+// BuildOpenAISSEReasoningDelta creates an SSE event for reasoning content delta
+// This is used for o1/o3 style models that expose reasoning tokens
+func BuildOpenAISSEReasoningDelta(state *OpenAIStreamState, reasoningDelta string) string {
+ delta := map[string]interface{}{
+ "reasoning_content": reasoningDelta,
+ }
+
+ // Include role in first chunk
+ if !state.HasSentFirstChunk {
+ delta["role"] = "assistant"
+ state.HasSentFirstChunk = true
+ }
+
+ chunk := buildBaseChunk(state, delta, nil)
+ result, _ := json.Marshal(chunk)
+ state.ChunkIndex++
+ return FormatSSEEvent(result)
+}
+
+// BuildOpenAISSEFirstChunk creates the first chunk with role only
+func BuildOpenAISSEFirstChunk(state *OpenAIStreamState) string {
+ delta := map[string]interface{}{
+ "role": "assistant",
+ "content": "",
+ }
+
+ state.HasSentFirstChunk = true
+ chunk := buildBaseChunk(state, delta, nil)
+ result, _ := json.Marshal(chunk)
+ state.ChunkIndex++
+ return FormatSSEEvent(result)
+}
+
+// ThinkingTagState tracks state for thinking tag detection in streaming
+type ThinkingTagState struct {
+ InThinkingBlock bool
+ PendingStartChars int
+ PendingEndChars int
+}
+
+// NewThinkingTagState creates a new thinking tag state
+func NewThinkingTagState() *ThinkingTagState {
+ return &ThinkingTagState{
+ InThinkingBlock: false,
+ PendingStartChars: 0,
+ PendingEndChars: 0,
+ }
+}
\ No newline at end of file
diff --git a/internal/watcher/events.go b/internal/watcher/events.go
index 250cf75c..fb96ad2a 100644
--- a/internal/watcher/events.go
+++ b/internal/watcher/events.go
@@ -13,6 +13,7 @@ import (
"time"
"github.com/fsnotify/fsnotify"
+ kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
log "github.com/sirupsen/logrus"
)
@@ -39,12 +40,35 @@ func (w *Watcher) start(ctx context.Context) error {
}
log.Debugf("watching auth directory: %s", w.authDir)
+ w.watchKiroIDETokenFile()
+
go w.processEvents(ctx)
w.reloadClients(true, nil, false)
return nil
}
+func (w *Watcher) watchKiroIDETokenFile() {
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ log.Debugf("failed to get home directory for Kiro IDE token watch: %v", err)
+ return
+ }
+
+ kiroTokenDir := filepath.Join(homeDir, ".aws", "sso", "cache")
+
+ if _, statErr := os.Stat(kiroTokenDir); os.IsNotExist(statErr) {
+ log.Debugf("Kiro IDE token directory does not exist: %s", kiroTokenDir)
+ return
+ }
+
+ if errAdd := w.watcher.Add(kiroTokenDir); errAdd != nil {
+ log.Debugf("failed to watch Kiro IDE token directory %s: %v", kiroTokenDir, errAdd)
+ return
+ }
+ log.Debugf("watching Kiro IDE token directory: %s", kiroTokenDir)
+}
+
func (w *Watcher) processEvents(ctx context.Context) {
for {
select {
@@ -73,11 +97,17 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
- if !isConfigEvent && !isAuthJSON {
+ isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0
+ if !isConfigEvent && !isAuthJSON && !isKiroIDEToken {
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
return
}
+ if isKiroIDEToken {
+ w.handleKiroIDETokenChange(event)
+ return
+ }
+
now := time.Now()
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
@@ -124,6 +154,44 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
}
}
+func (w *Watcher) isKiroIDETokenFile(path string) bool {
+ normalized := filepath.ToSlash(path)
+ return strings.HasSuffix(normalized, "kiro-auth-token.json") && strings.Contains(normalized, ".aws/sso/cache")
+}
+
+func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) {
+ log.Debugf("Kiro IDE token file event detected: %s %s", event.Op.String(), event.Name)
+
+ if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
+ time.Sleep(replaceCheckDelay)
+ if _, statErr := os.Stat(event.Name); statErr != nil {
+ log.Debugf("Kiro IDE token file removed: %s", event.Name)
+ return
+ }
+ }
+
+ // Use retry logic to handle file lock contention (e.g., Kiro IDE writing the file)
+ // This prevents "being used by another process" errors on Windows
+ tokenData, err := kiroauth.LoadKiroIDETokenWithRetry(10, 50*time.Millisecond)
+ if err != nil {
+ log.Debugf("failed to load Kiro IDE token after change: %v", err)
+ return
+ }
+
+ log.Infof("Kiro IDE token file updated, access token refreshed (provider: %s)", tokenData.Provider)
+
+ w.refreshAuthState(true)
+
+ w.clientsMutex.RLock()
+ cfg := w.config
+ w.clientsMutex.RUnlock()
+
+ if w.reloadCallback != nil && cfg != nil {
+ log.Debugf("triggering server update callback after Kiro IDE token change")
+ w.reloadCallback(cfg)
+ }
+}
+
func (w *Watcher) authFileUnchanged(path string) (bool, error) {
data, errRead := os.ReadFile(path)
if errRead != nil {
diff --git a/internal/watcher/synthesizer/config.go b/internal/watcher/synthesizer/config.go
index 69194efc..e044117f 100644
--- a/internal/watcher/synthesizer/config.go
+++ b/internal/watcher/synthesizer/config.go
@@ -5,8 +5,10 @@ import (
"strconv"
"strings"
+ kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+ log "github.com/sirupsen/logrus"
)
// ConfigSynthesizer generates Auth entries from configuration API keys.
@@ -31,6 +33,8 @@ func (s *ConfigSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth,
out = append(out, s.synthesizeClaudeKeys(ctx)...)
// Codex API Keys
out = append(out, s.synthesizeCodexKeys(ctx)...)
+ // Kiro (AWS CodeWhisperer)
+ out = append(out, s.synthesizeKiroKeys(ctx)...)
// OpenAI-compat
out = append(out, s.synthesizeOpenAICompat(ctx)...)
// Vertex-compat
@@ -320,3 +324,96 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor
}
return out
}
+
+// synthesizeKiroKeys creates Auth entries for Kiro (AWS CodeWhisperer) tokens.
+func (s *ConfigSynthesizer) synthesizeKiroKeys(ctx *SynthesisContext) []*coreauth.Auth {
+ cfg := ctx.Config
+ now := ctx.Now
+ idGen := ctx.IDGenerator
+
+ if len(cfg.KiroKey) == 0 {
+ return nil
+ }
+
+ out := make([]*coreauth.Auth, 0, len(cfg.KiroKey))
+ kAuth := kiroauth.NewKiroAuth(cfg)
+
+ for i := range cfg.KiroKey {
+ kk := cfg.KiroKey[i]
+ var accessToken, profileArn, refreshToken string
+
+ // Try to load from token file first
+ if kk.TokenFile != "" && kAuth != nil {
+ tokenData, err := kAuth.LoadTokenFromFile(kk.TokenFile)
+ if err != nil {
+ log.Warnf("failed to load kiro token file %s: %v", kk.TokenFile, err)
+ } else {
+ accessToken = tokenData.AccessToken
+ profileArn = tokenData.ProfileArn
+ refreshToken = tokenData.RefreshToken
+ }
+ }
+
+ // Override with direct config values if provided
+ if kk.AccessToken != "" {
+ accessToken = kk.AccessToken
+ }
+ if kk.ProfileArn != "" {
+ profileArn = kk.ProfileArn
+ }
+ if kk.RefreshToken != "" {
+ refreshToken = kk.RefreshToken
+ }
+
+ if accessToken == "" {
+ log.Warnf("kiro config[%d] missing access_token, skipping", i)
+ continue
+ }
+
+ // profileArn is optional for AWS Builder ID users
+ id, token := idGen.Next("kiro:token", accessToken, profileArn)
+ attrs := map[string]string{
+ "source": fmt.Sprintf("config:kiro[%s]", token),
+ "access_token": accessToken,
+ }
+ if profileArn != "" {
+ attrs["profile_arn"] = profileArn
+ }
+ if kk.Region != "" {
+ attrs["region"] = kk.Region
+ }
+ if kk.AgentTaskType != "" {
+ attrs["agent_task_type"] = kk.AgentTaskType
+ }
+ if kk.PreferredEndpoint != "" {
+ attrs["preferred_endpoint"] = kk.PreferredEndpoint
+ } else if cfg.KiroPreferredEndpoint != "" {
+ // Apply global default if not overridden by specific key
+ attrs["preferred_endpoint"] = cfg.KiroPreferredEndpoint
+ }
+ if refreshToken != "" {
+ attrs["refresh_token"] = refreshToken
+ }
+ proxyURL := strings.TrimSpace(kk.ProxyURL)
+ a := &coreauth.Auth{
+ ID: id,
+ Provider: "kiro",
+ Label: "kiro-token",
+ Status: coreauth.StatusActive,
+ ProxyURL: proxyURL,
+ Attributes: attrs,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+
+ if refreshToken != "" {
+ if a.Metadata == nil {
+ a.Metadata = make(map[string]any)
+ }
+ a.Metadata["refresh_token"] = refreshToken
+ }
+
+ out = append(out, a)
+ }
+ return out
+}
diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go
index 9f370127..a451ef6e 100644
--- a/internal/watcher/watcher.go
+++ b/internal/watcher/watcher.go
@@ -146,3 +146,111 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
w.clientsMutex.RUnlock()
return snapshotCoreAuths(cfg, w.authDir)
}
+
+// NotifyTokenRefreshed 处理后台刷新器的 token 更新通知
+// 当后台刷新器成功刷新 token 后调用此方法,更新内存中的 Auth 对象
+// tokenID: token 文件名(如 kiro-xxx.json)
+// accessToken: 新的 access token
+// refreshToken: 新的 refresh token
+// expiresAt: 新的过期时间
+func (w *Watcher) NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt string) {
+ if w == nil {
+ return
+ }
+
+ w.clientsMutex.Lock()
+ defer w.clientsMutex.Unlock()
+
+ // 遍历 currentAuths,找到匹配的 Auth 并更新
+ updated := false
+ for id, auth := range w.currentAuths {
+ if auth == nil || auth.Metadata == nil {
+ continue
+ }
+
+ // 检查是否是 kiro 类型的 auth
+ authType, _ := auth.Metadata["type"].(string)
+ if authType != "kiro" {
+ continue
+ }
+
+ // 多种匹配方式,解决不同来源的 auth 对象字段差异
+ matched := false
+
+ // 1. 通过 auth.ID 匹配(ID 可能包含文件名)
+ if !matched && auth.ID != "" {
+ if auth.ID == tokenID || strings.HasSuffix(auth.ID, "/"+tokenID) || strings.HasSuffix(auth.ID, "\\"+tokenID) {
+ matched = true
+ }
+ // ID 可能是 "kiro-xxx" 格式(无扩展名),tokenID 是 "kiro-xxx.json"
+ if !matched && strings.TrimSuffix(tokenID, ".json") == auth.ID {
+ matched = true
+ }
+ }
+
+ // 2. 通过 auth.Attributes["path"] 匹配
+ if !matched && auth.Attributes != nil {
+ if authPath := auth.Attributes["path"]; authPath != "" {
+ // 提取文件名部分进行比较
+ pathBase := authPath
+ if idx := strings.LastIndexAny(authPath, "/\\"); idx >= 0 {
+ pathBase = authPath[idx+1:]
+ }
+ if pathBase == tokenID || strings.TrimSuffix(pathBase, ".json") == strings.TrimSuffix(tokenID, ".json") {
+ matched = true
+ }
+ }
+ }
+
+ // 3. 通过 auth.FileName 匹配(原有逻辑)
+ if !matched && auth.FileName != "" {
+ if auth.FileName == tokenID || strings.HasSuffix(auth.FileName, "/"+tokenID) || strings.HasSuffix(auth.FileName, "\\"+tokenID) {
+ matched = true
+ }
+ }
+
+ if matched {
+ // 更新内存中的 token
+ auth.Metadata["access_token"] = accessToken
+ auth.Metadata["refresh_token"] = refreshToken
+ auth.Metadata["expires_at"] = expiresAt
+ auth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339)
+ auth.UpdatedAt = time.Now()
+ auth.LastRefreshedAt = time.Now()
+
+ log.Infof("watcher: updated in-memory auth for token %s (auth ID: %s)", tokenID, id)
+ updated = true
+
+ // 同时更新 runtimeAuths 中的副本(如果存在)
+ if w.runtimeAuths != nil {
+ if runtimeAuth, ok := w.runtimeAuths[id]; ok && runtimeAuth != nil {
+ if runtimeAuth.Metadata == nil {
+ runtimeAuth.Metadata = make(map[string]any)
+ }
+ runtimeAuth.Metadata["access_token"] = accessToken
+ runtimeAuth.Metadata["refresh_token"] = refreshToken
+ runtimeAuth.Metadata["expires_at"] = expiresAt
+ runtimeAuth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339)
+ runtimeAuth.UpdatedAt = time.Now()
+ runtimeAuth.LastRefreshedAt = time.Now()
+ }
+ }
+
+ // 发送更新通知到 authQueue
+ if w.authQueue != nil {
+ go func(authClone *coreauth.Auth) {
+ update := AuthUpdate{
+ Action: AuthUpdateActionModify,
+ ID: authClone.ID,
+ Auth: authClone,
+ }
+ w.dispatchAuthUpdates([]AuthUpdate{update})
+ }(auth.Clone())
+ }
+ }
+ }
+
+ if !updated {
+ log.Debugf("watcher: no matching auth found for token %s, will be picked up on next file scan", tokenID)
+ }
+}
diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go
index 23ef6535..7f842bbb 100644
--- a/sdk/api/handlers/handlers.go
+++ b/sdk/api/handlers/handlers.go
@@ -267,10 +267,11 @@ type BaseAPIHandler struct {
// Returns:
// - *BaseAPIHandler: A new API handlers instance
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
- return &BaseAPIHandler{
+ h := &BaseAPIHandler{
Cfg: cfg,
AuthManager: authManager,
}
+ return h
}
// UpdateClients updates the handlers' client list and configuration.
diff --git a/sdk/api/handlers/openai/endpoint_compat.go b/sdk/api/handlers/openai/endpoint_compat.go
new file mode 100644
index 00000000..d7fc5f2f
--- /dev/null
+++ b/sdk/api/handlers/openai/endpoint_compat.go
@@ -0,0 +1,37 @@
+package openai
+
+import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
+
+const (
+ openAIChatEndpoint = "/chat/completions"
+ openAIResponsesEndpoint = "/responses"
+)
+
+func resolveEndpointOverride(modelName, requestedEndpoint string) (string, bool) {
+ if modelName == "" {
+ return "", false
+ }
+ info := registry.GetGlobalRegistry().GetModelInfo(modelName, "")
+ if info == nil || len(info.SupportedEndpoints) == 0 {
+ return "", false
+ }
+ if endpointListContains(info.SupportedEndpoints, requestedEndpoint) {
+ return "", false
+ }
+ if requestedEndpoint == openAIChatEndpoint && endpointListContains(info.SupportedEndpoints, openAIResponsesEndpoint) {
+ return openAIResponsesEndpoint, true
+ }
+ if requestedEndpoint == openAIResponsesEndpoint && endpointListContains(info.SupportedEndpoints, openAIChatEndpoint) {
+ return openAIChatEndpoint, true
+ }
+ return "", false
+}
+
+func endpointListContains(items []string, value string) bool {
+ for _, item := range items {
+ if item == value {
+ return true
+ }
+ }
+ return false
+}
diff --git a/sdk/api/handlers/openai/openai_handlers.go b/sdk/api/handlers/openai/openai_handlers.go
index 9c161a1c..8f86f1b4 100644
--- a/sdk/api/handlers/openai/openai_handlers.go
+++ b/sdk/api/handlers/openai/openai_handlers.go
@@ -17,6 +17,7 @@ import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
+ codexconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/openai/chat-completions"
responsesconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/tidwall/gjson"
@@ -112,6 +113,23 @@ func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
streamResult := gjson.GetBytes(rawJSON, "stream")
stream := streamResult.Type == gjson.True
+ modelName := gjson.GetBytes(rawJSON, "model").String()
+ if overrideEndpoint, ok := resolveEndpointOverride(modelName, openAIChatEndpoint); ok && overrideEndpoint == openAIResponsesEndpoint {
+ originalChat := rawJSON
+ if shouldTreatAsResponsesFormat(rawJSON) {
+ // Already responses-style payload; no conversion needed.
+ } else {
+ rawJSON = codexconverter.ConvertOpenAIRequestToCodex(modelName, rawJSON, stream)
+ }
+ stream = gjson.GetBytes(rawJSON, "stream").Bool()
+ if stream {
+ h.handleStreamingResponseViaResponses(c, rawJSON, originalChat)
+ } else {
+ h.handleNonStreamingResponseViaResponses(c, rawJSON, originalChat)
+ }
+ return
+ }
+
// Some clients send OpenAI Responses-format payloads to /v1/chat/completions.
// Convert them to Chat Completions so downstream translators preserve tool metadata.
if shouldTreatAsResponsesFormat(rawJSON) {
@@ -245,6 +263,76 @@ func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
return []byte(out)
}
+func convertResponsesObjectToChatCompletion(ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON, responsesPayload []byte) []byte {
+ if len(responsesPayload) == 0 {
+ return nil
+ }
+ wrapped := wrapResponsesPayloadAsCompleted(responsesPayload)
+ if len(wrapped) == 0 {
+ return nil
+ }
+ var param any
+ converted := codexconverter.ConvertCodexResponseToOpenAINonStream(ctx, modelName, originalChatJSON, responsesRequestJSON, wrapped, ¶m)
+ if converted == "" {
+ return nil
+ }
+ return []byte(converted)
+}
+
+func wrapResponsesPayloadAsCompleted(payload []byte) []byte {
+ if gjson.GetBytes(payload, "type").Exists() {
+ return payload
+ }
+ if gjson.GetBytes(payload, "object").String() != "response" {
+ return payload
+ }
+ wrapped := `{"type":"response.completed","response":{}}`
+ wrapped, _ = sjson.SetRaw(wrapped, "response", string(payload))
+ return []byte(wrapped)
+}
+
+func writeConvertedResponsesChunk(c *gin.Context, ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON, chunk []byte, param *any) {
+ outputs := codexconverter.ConvertCodexResponseToOpenAI(ctx, modelName, originalChatJSON, responsesRequestJSON, chunk, param)
+ for _, out := range outputs {
+ if out == "" {
+ continue
+ }
+ _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", out)
+ }
+}
+
+func (h *OpenAIAPIHandler) forwardResponsesAsChatStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, ctx context.Context, modelName string, originalChatJSON, responsesRequestJSON []byte, param *any) {
+ h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
+ WriteChunk: func(chunk []byte) {
+ outputs := codexconverter.ConvertCodexResponseToOpenAI(ctx, modelName, originalChatJSON, responsesRequestJSON, chunk, param)
+ for _, out := range outputs {
+ if out == "" {
+ continue
+ }
+ _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", out)
+ }
+ },
+ WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
+ if errMsg == nil {
+ return
+ }
+ status := http.StatusInternalServerError
+ if errMsg.StatusCode > 0 {
+ status = errMsg.StatusCode
+ }
+ errText := http.StatusText(status)
+ if errMsg.Error != nil && errMsg.Error.Error() != "" {
+ errText = errMsg.Error.Error()
+ }
+ body := handlers.BuildErrorResponseBody(status, errText)
+ _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body))
+ },
+ WriteDone: func() {
+ _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
+ },
+ })
+}
+
// convertChatCompletionsResponseToCompletions converts chat completions API response back to completions format.
// This ensures the completions endpoint returns data in the expected format.
//
@@ -441,6 +529,30 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
cliCancel()
}
+func (h *OpenAIAPIHandler) handleNonStreamingResponseViaResponses(c *gin.Context, rawJSON []byte, originalChatJSON []byte) {
+ c.Header("Content-Type", "application/json")
+
+ modelName := gjson.GetBytes(rawJSON, "model").String()
+ cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
+ resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
+ if errMsg != nil {
+ h.WriteErrorResponse(c, errMsg)
+ cliCancel(errMsg.Error)
+ return
+ }
+ converted := convertResponsesObjectToChatCompletion(cliCtx, modelName, originalChatJSON, rawJSON, resp)
+ if converted == nil {
+ h.WriteErrorResponse(c, &interfaces.ErrorMessage{
+ StatusCode: http.StatusInternalServerError,
+ Error: fmt.Errorf("failed to convert response to chat completion format"),
+ })
+ cliCancel(fmt.Errorf("response conversion failed"))
+ return
+ }
+ _, _ = c.Writer.Write(converted)
+ cliCancel()
+}
+
// handleStreamingResponse handles streaming responses for Gemini models.
// It establishes a streaming connection with the backend service and forwards
// the response chunks to the client in real-time using Server-Sent Events.
@@ -515,6 +627,67 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
}
}
+func (h *OpenAIAPIHandler) handleStreamingResponseViaResponses(c *gin.Context, rawJSON []byte, originalChatJSON []byte) {
+ flusher, ok := c.Writer.(http.Flusher)
+ if !ok {
+ c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+ Error: handlers.ErrorDetail{
+ Message: "Streaming not supported",
+ Type: "server_error",
+ },
+ })
+ return
+ }
+
+ modelName := gjson.GetBytes(rawJSON, "model").String()
+ cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
+ dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenaiResponse, modelName, rawJSON, h.GetAlt(c))
+ var param any
+
+ setSSEHeaders := func() {
+ c.Header("Content-Type", "text/event-stream")
+ c.Header("Cache-Control", "no-cache")
+ c.Header("Connection", "keep-alive")
+ c.Header("Access-Control-Allow-Origin", "*")
+ }
+
+ // Peek for first usable chunk
+ for {
+ select {
+ case <-c.Request.Context().Done():
+ cliCancel(c.Request.Context().Err())
+ return
+ case errMsg, ok := <-errChan:
+ if !ok {
+ errChan = nil
+ continue
+ }
+ h.WriteErrorResponse(c, errMsg)
+ if errMsg != nil {
+ cliCancel(errMsg.Error)
+ } else {
+ cliCancel(nil)
+ }
+ return
+ case chunk, ok := <-dataChan:
+ if !ok {
+ setSSEHeaders()
+ _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
+ flusher.Flush()
+ cliCancel(nil)
+ return
+ }
+
+ setSSEHeaders()
+ writeConvertedResponsesChunk(c, cliCtx, modelName, originalChatJSON, rawJSON, chunk, ¶m)
+ flusher.Flush()
+
+ h.forwardResponsesAsChatStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, cliCtx, modelName, originalChatJSON, rawJSON, ¶m)
+ return
+ }
+ }
+}
+
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
// It converts completions request to chat completions format, sends to backend,
// then converts the response back to completions format before sending to client.
diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go
index 4b611af3..e18789e1 100644
--- a/sdk/api/handlers/openai/openai_responses_handlers.go
+++ b/sdk/api/handlers/openai/openai_responses_handlers.go
@@ -16,6 +16,7 @@ import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
+ responsesconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -84,7 +85,21 @@ func (h *OpenAIResponsesAPIHandler) Responses(c *gin.Context) {
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJSON, "stream")
- if streamResult.Type == gjson.True {
+ stream := streamResult.Type == gjson.True
+
+ modelName := gjson.GetBytes(rawJSON, "model").String()
+ if overrideEndpoint, ok := resolveEndpointOverride(modelName, openAIResponsesEndpoint); ok && overrideEndpoint == openAIChatEndpoint {
+ chatJSON := responsesconverter.ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName, rawJSON, stream)
+ stream = gjson.GetBytes(chatJSON, "stream").Bool()
+ if stream {
+ h.handleStreamingResponseViaChat(c, rawJSON, chatJSON)
+ } else {
+ h.handleNonStreamingResponseViaChat(c, rawJSON, chatJSON)
+ }
+ return
+ }
+
+ if stream {
h.handleStreamingResponse(c, rawJSON)
} else {
h.handleNonStreamingResponse(c, rawJSON)
@@ -160,6 +175,31 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
cliCancel()
}
+func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponseViaChat(c *gin.Context, originalResponsesJSON, chatJSON []byte) {
+ c.Header("Content-Type", "application/json")
+
+ modelName := gjson.GetBytes(chatJSON, "model").String()
+ cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
+ resp, errMsg := h.ExecuteWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
+ if errMsg != nil {
+ h.WriteErrorResponse(c, errMsg)
+ cliCancel(errMsg.Error)
+ return
+ }
+ var param any
+ converted := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(cliCtx, modelName, originalResponsesJSON, originalResponsesJSON, resp, ¶m)
+ if converted == "" {
+ h.WriteErrorResponse(c, &interfaces.ErrorMessage{
+ StatusCode: http.StatusInternalServerError,
+ Error: fmt.Errorf("failed to convert chat completion response to responses format"),
+ })
+ cliCancel(fmt.Errorf("response conversion failed"))
+ return
+ }
+ _, _ = c.Writer.Write([]byte(converted))
+ cliCancel()
+}
+
// handleStreamingResponse handles streaming responses for Gemini models.
// It establishes a streaming connection with the backend service and forwards
// the response chunks to the client in real-time using Server-Sent Events.
@@ -240,6 +280,116 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
}
}
+func (h *OpenAIResponsesAPIHandler) handleStreamingResponseViaChat(c *gin.Context, originalResponsesJSON, chatJSON []byte) {
+ flusher, ok := c.Writer.(http.Flusher)
+ if !ok {
+ c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+ Error: handlers.ErrorDetail{
+ Message: "Streaming not supported",
+ Type: "server_error",
+ },
+ })
+ return
+ }
+
+ modelName := gjson.GetBytes(chatJSON, "model").String()
+ cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
+ dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, OpenAI, modelName, chatJSON, "")
+ var param any
+
+ setSSEHeaders := func() {
+ c.Header("Content-Type", "text/event-stream")
+ c.Header("Cache-Control", "no-cache")
+ c.Header("Connection", "keep-alive")
+ c.Header("Access-Control-Allow-Origin", "*")
+ }
+
+ for {
+ select {
+ case <-c.Request.Context().Done():
+ cliCancel(c.Request.Context().Err())
+ return
+ case errMsg, ok := <-errChan:
+ if !ok {
+ errChan = nil
+ continue
+ }
+ h.WriteErrorResponse(c, errMsg)
+ if errMsg != nil {
+ cliCancel(errMsg.Error)
+ } else {
+ cliCancel(nil)
+ }
+ return
+ case chunk, ok := <-dataChan:
+ if !ok {
+ setSSEHeaders()
+ _, _ = c.Writer.Write([]byte("\n"))
+ flusher.Flush()
+ cliCancel(nil)
+ return
+ }
+
+ setSSEHeaders()
+ writeChatAsResponsesChunk(c, cliCtx, modelName, originalResponsesJSON, chunk, ¶m)
+ flusher.Flush()
+
+ h.forwardChatAsResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, cliCtx, modelName, originalResponsesJSON, ¶m)
+ return
+ }
+ }
+}
+
+func writeChatAsResponsesChunk(c *gin.Context, ctx context.Context, modelName string, originalResponsesJSON, chunk []byte, param *any) {
+ outputs := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx, modelName, originalResponsesJSON, originalResponsesJSON, chunk, param)
+ for _, out := range outputs {
+ if out == "" {
+ continue
+ }
+ if bytes.HasPrefix([]byte(out), []byte("event:")) {
+ _, _ = c.Writer.Write([]byte("\n"))
+ }
+ _, _ = c.Writer.Write([]byte(out))
+ _, _ = c.Writer.Write([]byte("\n"))
+ }
+}
+
+func (h *OpenAIResponsesAPIHandler) forwardChatAsResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, ctx context.Context, modelName string, originalResponsesJSON []byte, param *any) {
+ h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
+ WriteChunk: func(chunk []byte) {
+ outputs := responsesconverter.ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx, modelName, originalResponsesJSON, originalResponsesJSON, chunk, param)
+ for _, out := range outputs {
+ if out == "" {
+ continue
+ }
+ if bytes.HasPrefix([]byte(out), []byte("event:")) {
+ _, _ = c.Writer.Write([]byte("\n"))
+ }
+ _, _ = c.Writer.Write([]byte(out))
+ _, _ = c.Writer.Write([]byte("\n"))
+ }
+ },
+ WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
+ if errMsg == nil {
+ return
+ }
+ status := http.StatusInternalServerError
+ if errMsg.StatusCode > 0 {
+ status = errMsg.StatusCode
+ }
+ errText := http.StatusText(status)
+ if errMsg.Error != nil && errMsg.Error.Error() != "" {
+ errText = errMsg.Error.Error()
+ }
+ body := handlers.BuildErrorResponseBody(status, errText)
+ _, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
+ },
+ WriteDone: func() {
+ _, _ = c.Writer.Write([]byte("\n"))
+ },
+ })
+}
+
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
WriteChunk: func(chunk []byte) {
diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go
index 795bba0d..4715d7f7 100644
--- a/sdk/auth/filestore.go
+++ b/sdk/auth/filestore.go
@@ -228,6 +228,15 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
if disabled {
status = cliproxyauth.StatusDisabled
}
+
+ // Calculate NextRefreshAfter from expires_at (20 minutes before expiry)
+ var nextRefreshAfter time.Time
+ if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
+ if expiresAt, err := time.Parse(time.RFC3339, expiresAtStr); err == nil {
+ nextRefreshAfter = expiresAt.Add(-20 * time.Minute)
+ }
+ }
+
auth := &cliproxyauth.Auth{
ID: id,
Provider: provider,
@@ -240,7 +249,7 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
CreatedAt: info.ModTime(),
UpdatedAt: info.ModTime(),
LastRefreshedAt: time.Time{},
- NextRefreshAfter: time.Time{},
+ NextRefreshAfter: nextRefreshAfter,
}
if email, ok := metadata["email"].(string); ok && email != "" {
auth.Attributes["email"] = email
diff --git a/sdk/auth/github_copilot.go b/sdk/auth/github_copilot.go
new file mode 100644
index 00000000..1d14ac47
--- /dev/null
+++ b/sdk/auth/github_copilot.go
@@ -0,0 +1,129 @@
+package auth
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+ log "github.com/sirupsen/logrus"
+)
+
+// GitHubCopilotAuthenticator implements the OAuth device flow login for GitHub Copilot.
+type GitHubCopilotAuthenticator struct{}
+
+// NewGitHubCopilotAuthenticator constructs a new GitHub Copilot authenticator.
+func NewGitHubCopilotAuthenticator() Authenticator {
+ return &GitHubCopilotAuthenticator{}
+}
+
+// Provider returns the provider key for github-copilot.
+func (GitHubCopilotAuthenticator) Provider() string {
+ return "github-copilot"
+}
+
+// RefreshLead returns nil since GitHub OAuth tokens don't expire in the traditional sense.
+// The token remains valid until the user revokes it or the Copilot subscription expires.
+func (GitHubCopilotAuthenticator) RefreshLead() *time.Duration {
+ return nil
+}
+
+// Login initiates the GitHub device flow authentication for Copilot access.
+func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
+ if cfg == nil {
+ return nil, fmt.Errorf("cliproxy auth: configuration is required")
+ }
+ if opts == nil {
+ opts = &LoginOptions{}
+ }
+
+ authSvc := copilot.NewCopilotAuth(cfg)
+
+ // Start the device flow
+ fmt.Println("Starting GitHub Copilot authentication...")
+ deviceCode, err := authSvc.StartDeviceFlow(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("github-copilot: failed to start device flow: %w", err)
+ }
+
+ // Display the user code and verification URL
+ fmt.Printf("\nTo authenticate, please visit: %s\n", deviceCode.VerificationURI)
+ fmt.Printf("And enter the code: %s\n\n", deviceCode.UserCode)
+
+ // Try to open the browser automatically
+ if !opts.NoBrowser {
+ if browser.IsAvailable() {
+ if errOpen := browser.OpenURL(deviceCode.VerificationURI); errOpen != nil {
+ log.Warnf("Failed to open browser automatically: %v", errOpen)
+ }
+ }
+ }
+
+ fmt.Println("Waiting for GitHub authorization...")
+ fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn)
+
+ // Wait for user authorization
+ authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode)
+ if err != nil {
+ errMsg := copilot.GetUserFriendlyMessage(err)
+ return nil, fmt.Errorf("github-copilot: %s", errMsg)
+ }
+
+ // Verify the token can get a Copilot API token
+ fmt.Println("Verifying Copilot access...")
+ apiToken, err := authSvc.GetCopilotAPIToken(ctx, authBundle.TokenData.AccessToken)
+ if err != nil {
+ return nil, fmt.Errorf("github-copilot: failed to verify Copilot access - you may not have an active Copilot subscription: %w", err)
+ }
+
+ // Create the token storage
+ tokenStorage := authSvc.CreateTokenStorage(authBundle)
+
+ // Build metadata with token information for the executor
+ metadata := map[string]any{
+ "type": "github-copilot",
+ "username": authBundle.Username,
+ "access_token": authBundle.TokenData.AccessToken,
+ "token_type": authBundle.TokenData.TokenType,
+ "scope": authBundle.TokenData.Scope,
+ "timestamp": time.Now().UnixMilli(),
+ }
+
+ if apiToken.ExpiresAt > 0 {
+ metadata["api_token_expires_at"] = apiToken.ExpiresAt
+ }
+
+ fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username)
+
+ fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username)
+
+ return &coreauth.Auth{
+ ID: fileName,
+ Provider: a.Provider(),
+ FileName: fileName,
+ Label: authBundle.Username,
+ Storage: tokenStorage,
+ Metadata: metadata,
+ }, nil
+}
+
+// RefreshGitHubCopilotToken validates and returns the current token status.
+// GitHub OAuth tokens don't need traditional refresh - we just validate they still work.
+func RefreshGitHubCopilotToken(ctx context.Context, cfg *config.Config, storage *copilot.CopilotTokenStorage) error {
+ if storage == nil || storage.AccessToken == "" {
+ return fmt.Errorf("no token available")
+ }
+
+ authSvc := copilot.NewCopilotAuth(cfg)
+
+ // Validate the token can still get a Copilot API token
+ _, err := authSvc.GetCopilotAPIToken(ctx, storage.AccessToken)
+ if err != nil {
+ return fmt.Errorf("token validation failed: %w", err)
+ }
+
+ return nil
+}
diff --git a/sdk/auth/kilo.go b/sdk/auth/kilo.go
new file mode 100644
index 00000000..7e98f7c4
--- /dev/null
+++ b/sdk/auth/kilo.go
@@ -0,0 +1,121 @@
+package auth
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kilo"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+)
+
+// KiloAuthenticator implements the login flow for Kilo AI accounts.
+type KiloAuthenticator struct{}
+
+// NewKiloAuthenticator constructs a Kilo authenticator.
+func NewKiloAuthenticator() *KiloAuthenticator {
+ return &KiloAuthenticator{}
+}
+
+func (a *KiloAuthenticator) Provider() string {
+ return "kilo"
+}
+
+func (a *KiloAuthenticator) RefreshLead() *time.Duration {
+ return nil
+}
+
+// Login manages the device flow authentication for Kilo AI.
+func (a *KiloAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
+ if cfg == nil {
+ return nil, fmt.Errorf("cliproxy auth: configuration is required")
+ }
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ if opts == nil {
+ opts = &LoginOptions{}
+ }
+
+ kilocodeAuth := kilo.NewKiloAuth()
+
+ fmt.Println("Initiating Kilo device authentication...")
+ resp, err := kilocodeAuth.InitiateDeviceFlow(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to initiate device flow: %w", err)
+ }
+
+ fmt.Printf("Please visit: %s\n", resp.VerificationURL)
+ fmt.Printf("And enter code: %s\n", resp.Code)
+
+ fmt.Println("Waiting for authorization...")
+ status, err := kilocodeAuth.PollForToken(ctx, resp.Code)
+ if err != nil {
+ return nil, fmt.Errorf("authentication failed: %w", err)
+ }
+
+ fmt.Printf("Authentication successful for %s\n", status.UserEmail)
+
+ profile, err := kilocodeAuth.GetProfile(ctx, status.Token)
+ if err != nil {
+ return nil, fmt.Errorf("failed to fetch profile: %w", err)
+ }
+
+ var orgID string
+ if len(profile.Orgs) > 1 {
+ fmt.Println("Multiple organizations found. Please select one:")
+ for i, org := range profile.Orgs {
+ fmt.Printf("[%d] %s (%s)\n", i+1, org.Name, org.ID)
+ }
+
+ if opts.Prompt != nil {
+ input, err := opts.Prompt("Enter the number of the organization: ")
+ if err != nil {
+ return nil, err
+ }
+ var choice int
+ _, err = fmt.Sscan(input, &choice)
+ if err == nil && choice > 0 && choice <= len(profile.Orgs) {
+ orgID = profile.Orgs[choice-1].ID
+ } else {
+ orgID = profile.Orgs[0].ID
+ fmt.Printf("Invalid choice, defaulting to %s\n", profile.Orgs[0].Name)
+ }
+ } else {
+ orgID = profile.Orgs[0].ID
+ fmt.Printf("Non-interactive mode, defaulting to organization: %s\n", profile.Orgs[0].Name)
+ }
+ } else if len(profile.Orgs) == 1 {
+ orgID = profile.Orgs[0].ID
+ }
+
+ defaults, err := kilocodeAuth.GetDefaults(ctx, status.Token, orgID)
+ if err != nil {
+ fmt.Printf("Warning: failed to fetch defaults: %v\n", err)
+ defaults = &kilo.Defaults{}
+ }
+
+ ts := &kilo.KiloTokenStorage{
+ Token: status.Token,
+ OrganizationID: orgID,
+ Model: defaults.Model,
+ Email: status.UserEmail,
+ Type: "kilo",
+ }
+
+ fileName := kilo.CredentialFileName(status.UserEmail)
+ metadata := map[string]any{
+ "email": status.UserEmail,
+ "organization_id": orgID,
+ "model": defaults.Model,
+ }
+
+ return &coreauth.Auth{
+ ID: fileName,
+ Provider: a.Provider(),
+ FileName: fileName,
+ Storage: ts,
+ Metadata: metadata,
+ }, nil
+}
diff --git a/sdk/auth/kiro.go b/sdk/auth/kiro.go
new file mode 100644
index 00000000..ad165b75
--- /dev/null
+++ b/sdk/auth/kiro.go
@@ -0,0 +1,446 @@
+package auth
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+)
+
+// extractKiroIdentifier extracts a meaningful identifier for file naming.
+// Returns account name if provided, otherwise profile ARN ID, then client ID.
+// All extracted values are sanitized to prevent path injection attacks.
+func extractKiroIdentifier(accountName, profileArn, clientID string) string {
+ // Priority 1: Use account name if provided
+ if accountName != "" {
+ return kiroauth.SanitizeEmailForFilename(accountName)
+ }
+
+ // Priority 2: Use profile ARN ID part (sanitized to prevent path injection)
+ if profileArn != "" {
+ parts := strings.Split(profileArn, "/")
+ if len(parts) >= 2 {
+ // Sanitize the ARN component to prevent path traversal
+ return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1])
+ }
+ }
+
+ // Priority 3: Use client ID (for IDC auth without email/profileArn)
+ if clientID != "" {
+ return kiroauth.SanitizeEmailForFilename(clientID)
+ }
+
+ // Fallback: timestamp
+ return fmt.Sprintf("%d", time.Now().UnixNano()%100000)
+}
+
+// KiroAuthenticator implements OAuth authentication for Kiro with Google login.
+type KiroAuthenticator struct{}
+
+// NewKiroAuthenticator constructs a Kiro authenticator.
+func NewKiroAuthenticator() *KiroAuthenticator {
+ return &KiroAuthenticator{}
+}
+
+// Provider returns the provider key for the authenticator.
+func (a *KiroAuthenticator) Provider() string {
+ return "kiro"
+}
+
+// RefreshLead indicates how soon before expiry a refresh should be attempted.
+// Set to 20 minutes for proactive refresh before token expiry.
+func (a *KiroAuthenticator) RefreshLead() *time.Duration {
+ d := 20 * time.Minute
+ return &d
+}
+
+// createAuthRecord creates an auth record from token data.
+func (a *KiroAuthenticator) createAuthRecord(tokenData *kiroauth.KiroTokenData, source string) (*coreauth.Auth, error) {
+ // Parse expires_at
+ expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
+ if err != nil {
+ expiresAt = time.Now().Add(1 * time.Hour)
+ }
+
+ // Determine label and identifier based on auth method
+ // Generate sequence number for uniqueness
+ seq := time.Now().UnixNano() % 100000
+
+ var label, idPart string
+ if tokenData.AuthMethod == "idc" {
+ label = "kiro-idc"
+ // Priority: email > startUrl identifier > sequence only
+ // Email is unique, so no sequence needed when email is available
+ if tokenData.Email != "" {
+ idPart = kiroauth.SanitizeEmailForFilename(tokenData.Email)
+ } else if tokenData.StartURL != "" {
+ identifier := kiroauth.ExtractIDCIdentifier(tokenData.StartURL)
+ if identifier != "" {
+ idPart = fmt.Sprintf("%s-%05d", identifier, seq)
+ } else {
+ idPart = fmt.Sprintf("%05d", seq)
+ }
+ } else {
+ idPart = fmt.Sprintf("%05d", seq)
+ }
+ } else {
+ label = fmt.Sprintf("kiro-%s", source)
+ idPart = extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
+ }
+
+ now := time.Now()
+ fileName := fmt.Sprintf("%s-%s.json", label, idPart)
+
+ metadata := map[string]any{
+ "type": "kiro",
+ "access_token": tokenData.AccessToken,
+ "refresh_token": tokenData.RefreshToken,
+ "profile_arn": tokenData.ProfileArn,
+ "expires_at": tokenData.ExpiresAt,
+ "auth_method": tokenData.AuthMethod,
+ "provider": tokenData.Provider,
+ "client_id": tokenData.ClientID,
+ "client_secret": tokenData.ClientSecret,
+ "email": tokenData.Email,
+ }
+
+ // Add IDC-specific fields if present
+ if tokenData.StartURL != "" {
+ metadata["start_url"] = tokenData.StartURL
+ }
+ if tokenData.Region != "" {
+ metadata["region"] = tokenData.Region
+ }
+
+ attributes := map[string]string{
+ "profile_arn": tokenData.ProfileArn,
+ "source": source,
+ "email": tokenData.Email,
+ }
+
+ // Add IDC-specific attributes if present
+ if tokenData.AuthMethod == "idc" {
+ attributes["source"] = "aws-idc"
+ if tokenData.StartURL != "" {
+ attributes["start_url"] = tokenData.StartURL
+ }
+ if tokenData.Region != "" {
+ attributes["region"] = tokenData.Region
+ }
+ }
+
+ record := &coreauth.Auth{
+ ID: fileName,
+ Provider: "kiro",
+ FileName: fileName,
+ Label: label,
+ Status: coreauth.StatusActive,
+ CreatedAt: now,
+ UpdatedAt: now,
+ Metadata: metadata,
+ Attributes: attributes,
+ // NextRefreshAfter: 20 minutes before expiry
+ NextRefreshAfter: expiresAt.Add(-20 * time.Minute),
+ }
+
+ if tokenData.Email != "" {
+ fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
+ } else {
+ fmt.Println("\n✓ Kiro authentication completed successfully!")
+ }
+
+ return record, nil
+}
+
+// Login performs OAuth login for Kiro with AWS (Builder ID or IDC).
+// This shows a method selection prompt and handles both flows.
+func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
+ if cfg == nil {
+ return nil, fmt.Errorf("kiro auth: configuration is required")
+ }
+
+ // Use the unified method selection flow (Builder ID or IDC)
+ ssoClient := kiroauth.NewSSOOIDCClient(cfg)
+ tokenData, err := ssoClient.LoginWithMethodSelection(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("login failed: %w", err)
+ }
+
+ return a.createAuthRecord(tokenData, "aws")
+}
+
+// LoginWithAuthCode performs OAuth login for Kiro with AWS Builder ID using authorization code flow.
+// This provides a better UX than device code flow as it uses automatic browser callback.
+func (a *KiroAuthenticator) LoginWithAuthCode(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
+ if cfg == nil {
+ return nil, fmt.Errorf("kiro auth: configuration is required")
+ }
+
+ oauth := kiroauth.NewKiroOAuth(cfg)
+
+ // Use AWS Builder ID authorization code flow
+ tokenData, err := oauth.LoginWithBuilderIDAuthCode(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("login failed: %w", err)
+ }
+
+ // Parse expires_at
+ expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
+ if err != nil {
+ expiresAt = time.Now().Add(1 * time.Hour)
+ }
+
+ // Extract identifier for file naming
+ idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
+
+ now := time.Now()
+ fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
+
+ record := &coreauth.Auth{
+ ID: fileName,
+ Provider: "kiro",
+ FileName: fileName,
+ Label: "kiro-aws",
+ Status: coreauth.StatusActive,
+ CreatedAt: now,
+ UpdatedAt: now,
+ Metadata: map[string]any{
+ "type": "kiro",
+ "access_token": tokenData.AccessToken,
+ "refresh_token": tokenData.RefreshToken,
+ "profile_arn": tokenData.ProfileArn,
+ "expires_at": tokenData.ExpiresAt,
+ "auth_method": tokenData.AuthMethod,
+ "provider": tokenData.Provider,
+ "client_id": tokenData.ClientID,
+ "client_secret": tokenData.ClientSecret,
+ "email": tokenData.Email,
+ },
+ Attributes: map[string]string{
+ "profile_arn": tokenData.ProfileArn,
+ "source": "aws-builder-id-authcode",
+ "email": tokenData.Email,
+ },
+ // NextRefreshAfter: 20 minutes before expiry
+ NextRefreshAfter: expiresAt.Add(-20 * time.Minute),
+ }
+
+ if tokenData.Email != "" {
+ fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
+ } else {
+ fmt.Println("\n✓ Kiro authentication completed successfully!")
+ }
+
+ return record, nil
+}
+
+// LoginWithGoogle performs OAuth login for Kiro with Google.
+// NOTE: Google login is not available for third-party applications due to AWS Cognito restrictions.
+// Please use AWS Builder ID or import your token from Kiro IDE.
+func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
+ return nil, fmt.Errorf("Google login is not available for third-party applications due to AWS Cognito restrictions.\n\nAlternatives:\n 1. Use AWS Builder ID: cliproxy kiro --builder-id\n 2. Import token from Kiro IDE: cliproxy kiro --import\n\nTo get a token from Kiro IDE:\n 1. Open Kiro IDE and login with Google\n 2. Find: ~/.kiro/kiro-auth-token.json\n 3. Run: cliproxy kiro --import")
+}
+
+// LoginWithGitHub performs OAuth login for Kiro with GitHub.
+// NOTE: GitHub login is not available for third-party applications due to AWS Cognito restrictions.
+// Please use AWS Builder ID or import your token from Kiro IDE.
+func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
+ return nil, fmt.Errorf("GitHub login is not available for third-party applications due to AWS Cognito restrictions.\n\nAlternatives:\n 1. Use AWS Builder ID: cliproxy kiro --builder-id\n 2. Import token from Kiro IDE: cliproxy kiro --import\n\nTo get a token from Kiro IDE:\n 1. Open Kiro IDE and login with GitHub\n 2. Find: ~/.kiro/kiro-auth-token.json\n 3. Run: cliproxy kiro --import")
+}
+
+// ImportFromKiroIDE imports token from Kiro IDE's token file.
+func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) {
+ tokenData, err := kiroauth.LoadKiroIDEToken()
+ if err != nil {
+ return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err)
+ }
+
+ // Parse expires_at
+ expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
+ if err != nil {
+ expiresAt = time.Now().Add(1 * time.Hour)
+ }
+
+ // Extract email from JWT if not already set (for imported tokens)
+ if tokenData.Email == "" {
+ tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken)
+ }
+
+ // Extract identifier for file naming
+ idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn, tokenData.ClientID)
+ // Sanitize provider to prevent path traversal (defense-in-depth)
+ provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider)))
+ if provider == "" {
+ provider = "imported" // Fallback for legacy tokens without provider
+ }
+
+ now := time.Now()
+ fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart)
+
+ record := &coreauth.Auth{
+ ID: fileName,
+ Provider: "kiro",
+ FileName: fileName,
+ Label: fmt.Sprintf("kiro-%s", provider),
+ Status: coreauth.StatusActive,
+ CreatedAt: now,
+ UpdatedAt: now,
+ Metadata: map[string]any{
+ "type": "kiro",
+ "access_token": tokenData.AccessToken,
+ "refresh_token": tokenData.RefreshToken,
+ "profile_arn": tokenData.ProfileArn,
+ "expires_at": tokenData.ExpiresAt,
+ "auth_method": tokenData.AuthMethod,
+ "provider": tokenData.Provider,
+ "client_id": tokenData.ClientID,
+ "client_secret": tokenData.ClientSecret,
+ "client_id_hash": tokenData.ClientIDHash,
+ "email": tokenData.Email,
+ "region": tokenData.Region,
+ "start_url": tokenData.StartURL,
+ },
+ Attributes: map[string]string{
+ "profile_arn": tokenData.ProfileArn,
+ "source": "kiro-ide-import",
+ "email": tokenData.Email,
+ "region": tokenData.Region,
+ },
+ // NextRefreshAfter: 20 minutes before expiry
+ NextRefreshAfter: expiresAt.Add(-20 * time.Minute),
+ }
+
+ // Display the email if extracted
+ if tokenData.Email != "" {
+ fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email)
+ } else {
+ fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider)
+ }
+
+ return record, nil
+}
+
+// Refresh refreshes an expired Kiro token using AWS SSO OIDC.
+func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) {
+ if auth == nil || auth.Metadata == nil {
+ return nil, fmt.Errorf("invalid auth record")
+ }
+
+ refreshToken, ok := auth.Metadata["refresh_token"].(string)
+ if !ok || refreshToken == "" {
+ return nil, fmt.Errorf("refresh token not found")
+ }
+
+ clientID, _ := auth.Metadata["client_id"].(string)
+ clientSecret, _ := auth.Metadata["client_secret"].(string)
+ clientIDHash, _ := auth.Metadata["client_id_hash"].(string)
+ authMethod, _ := auth.Metadata["auth_method"].(string)
+ startURL, _ := auth.Metadata["start_url"].(string)
+ region, _ := auth.Metadata["region"].(string)
+
+ // For Enterprise Kiro IDE (IDC auth), try to load clientId/clientSecret from device registration
+ // if they are missing from metadata. This handles the case where token was imported without
+ // clientId/clientSecret but has clientIdHash.
+ if (clientID == "" || clientSecret == "") && clientIDHash != "" {
+ if loadedClientID, loadedClientSecret, err := loadDeviceRegistrationCredentials(clientIDHash); err == nil {
+ clientID = loadedClientID
+ clientSecret = loadedClientSecret
+ }
+ }
+
+ var tokenData *kiroauth.KiroTokenData
+ var err error
+
+ ssoClient := kiroauth.NewSSOOIDCClient(cfg)
+
+ // Use SSO OIDC refresh for AWS Builder ID or IDC, otherwise use Kiro's OAuth refresh endpoint
+ switch {
+ case clientID != "" && clientSecret != "" && authMethod == "idc" && region != "":
+ // IDC refresh with region-specific endpoint
+ tokenData, err = ssoClient.RefreshTokenWithRegion(ctx, clientID, clientSecret, refreshToken, region, startURL)
+ case clientID != "" && clientSecret != "" && (authMethod == "builder-id" || authMethod == "idc"):
+ // Builder ID or IDC refresh with default endpoint (us-east-1)
+ tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
+ default:
+ // Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
+ oauth := kiroauth.NewKiroOAuth(cfg)
+ tokenData, err = oauth.RefreshToken(ctx, refreshToken)
+ }
+
+ if err != nil {
+ return nil, fmt.Errorf("token refresh failed: %w", err)
+ }
+
+ // Parse expires_at
+ expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
+ if err != nil {
+ expiresAt = time.Now().Add(1 * time.Hour)
+ }
+
+ // Clone auth to avoid mutating the input parameter
+ updated := auth.Clone()
+ now := time.Now()
+ updated.UpdatedAt = now
+ updated.LastRefreshedAt = now
+ updated.Metadata["access_token"] = tokenData.AccessToken
+ updated.Metadata["refresh_token"] = tokenData.RefreshToken
+ updated.Metadata["expires_at"] = tokenData.ExpiresAt
+ updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
+ // Store clientId/clientSecret if they were loaded from device registration
+ if clientID != "" && updated.Metadata["client_id"] == nil {
+ updated.Metadata["client_id"] = clientID
+ }
+ if clientSecret != "" && updated.Metadata["client_secret"] == nil {
+ updated.Metadata["client_secret"] = clientSecret
+ }
+ // NextRefreshAfter: 20 minutes before expiry
+ updated.NextRefreshAfter = expiresAt.Add(-20 * time.Minute)
+
+ return updated, nil
+}
+
+// loadDeviceRegistrationCredentials loads clientId and clientSecret from device registration file.
+// This is used when refreshing tokens that were imported without clientId/clientSecret.
+func loadDeviceRegistrationCredentials(clientIDHash string) (clientID, clientSecret string, err error) {
+ if clientIDHash == "" {
+ return "", "", fmt.Errorf("clientIdHash is empty")
+ }
+
+ // Sanitize clientIdHash to prevent path traversal
+ if strings.Contains(clientIDHash, "/") || strings.Contains(clientIDHash, "\\") || strings.Contains(clientIDHash, "..") {
+ return "", "", fmt.Errorf("invalid clientIdHash: contains path separator")
+ }
+
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ return "", "", fmt.Errorf("failed to get home directory: %w", err)
+ }
+
+ deviceRegPath := filepath.Join(homeDir, ".aws", "sso", "cache", clientIDHash+".json")
+ data, err := os.ReadFile(deviceRegPath)
+ if err != nil {
+ return "", "", fmt.Errorf("failed to read device registration file: %w", err)
+ }
+
+ var deviceReg struct {
+ ClientID string `json:"clientId"`
+ ClientSecret string `json:"clientSecret"`
+ }
+
+ if err := json.Unmarshal(data, &deviceReg); err != nil {
+ return "", "", fmt.Errorf("failed to parse device registration: %w", err)
+ }
+
+ if deviceReg.ClientID == "" || deviceReg.ClientSecret == "" {
+ return "", "", fmt.Errorf("device registration missing clientId or clientSecret")
+ }
+
+ return deviceReg.ClientID, deviceReg.ClientSecret, nil
+}
diff --git a/sdk/auth/manager.go b/sdk/auth/manager.go
index c6469a7d..d630f128 100644
--- a/sdk/auth/manager.go
+++ b/sdk/auth/manager.go
@@ -74,3 +74,16 @@ func (m *Manager) Login(ctx context.Context, provider string, cfg *config.Config
}
return record, savedPath, nil
}
+
+// SaveAuth persists an auth record directly without going through the login flow.
+func (m *Manager) SaveAuth(record *coreauth.Auth, cfg *config.Config) (string, error) {
+ if m.store == nil {
+ return "", fmt.Errorf("no store configured")
+ }
+ if cfg != nil {
+ if dirSetter, ok := m.store.(interface{ SetBaseDir(string) }); ok {
+ dirSetter.SetBaseDir(cfg.AuthDir)
+ }
+ }
+ return m.store.Save(context.Background(), record)
+}
diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go
index bf7f1448..ecf8e820 100644
--- a/sdk/auth/refresh_registry.go
+++ b/sdk/auth/refresh_registry.go
@@ -15,6 +15,8 @@ func init() {
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
+ registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
+ registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
}
func registerRefreshLead(provider string, factory func() Authenticator) {
diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go
index 76aae228..87592512 100644
--- a/sdk/cliproxy/auth/conductor.go
+++ b/sdk/cliproxy/auth/conductor.go
@@ -58,9 +58,9 @@ type RefreshEvaluator interface {
}
const (
- refreshCheckInterval = 5 * time.Second
+ refreshCheckInterval = 30 * time.Second
refreshPendingBackoff = time.Minute
- refreshFailureBackoff = 5 * time.Minute
+ refreshFailureBackoff = 1 * time.Minute
quotaBackoffBase = time.Second
quotaBackoffMax = 30 * time.Minute
)
@@ -2152,7 +2152,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
updated.Runtime = auth.Runtime
}
updated.LastRefreshedAt = now
- updated.NextRefreshAfter = time.Time{}
+ // Preserve NextRefreshAfter set by the Authenticator
+ // If the Authenticator set a reasonable refresh time, it should not be overwritten
+ // If the Authenticator did not set it (zero value), shouldRefresh will use default logic
updated.LastError = nil
updated.UpdatedAt = now
_, _ = m.Update(ctx, updated)
diff --git a/sdk/cliproxy/auth/oauth_model_alias.go b/sdk/cliproxy/auth/oauth_model_alias.go
index d5d2ff8a..8563aac4 100644
--- a/sdk/cliproxy/auth/oauth_model_alias.go
+++ b/sdk/cliproxy/auth/oauth_model_alias.go
@@ -221,7 +221,7 @@ func modelAliasChannel(auth *Auth) string {
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
// OAuth model alias (e.g., API key authentication).
//
-// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kimi.
+// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kiro, github-copilot, kimi.
func OAuthModelAliasChannel(provider, authKind string) string {
provider = strings.ToLower(strings.TrimSpace(provider))
authKind = strings.ToLower(strings.TrimSpace(authKind))
@@ -245,7 +245,7 @@ func OAuthModelAliasChannel(provider, authKind string) string {
return ""
}
return "codex"
- case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kimi":
+ case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kiro", "github-copilot", "kimi":
return provider
default:
return ""
diff --git a/sdk/cliproxy/auth/oauth_model_alias_test.go b/sdk/cliproxy/auth/oauth_model_alias_test.go
index 32390959..e12b6597 100644
--- a/sdk/cliproxy/auth/oauth_model_alias_test.go
+++ b/sdk/cliproxy/auth/oauth_model_alias_test.go
@@ -43,6 +43,15 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
input: "gemini-2.5-pro",
want: "gemini-2.5-pro-exp-03-25",
},
+ {
+ name: "kiro alias resolves",
+ aliases: map[string][]internalconfig.OAuthModelAlias{
+ "kiro": {{Name: "kiro-claude-sonnet-4-5", Alias: "sonnet"}},
+ },
+ channel: "kiro",
+ input: "sonnet",
+ want: "kiro-claude-sonnet-4-5",
+ },
{
name: "config suffix takes priority",
aliases: map[string][]internalconfig.OAuthModelAlias{
@@ -70,6 +79,24 @@ func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
input: "gemini-2.5-pro(none)",
want: "gemini-2.5-pro-exp-03-25(none)",
},
+ {
+ name: "github-copilot suffix preserved",
+ aliases: map[string][]internalconfig.OAuthModelAlias{
+ "github-copilot": {{Name: "claude-opus-4.6", Alias: "opus"}},
+ },
+ channel: "github-copilot",
+ input: "opus(medium)",
+ want: "claude-opus-4.6(medium)",
+ },
+ {
+ name: "github-copilot no suffix",
+ aliases: map[string][]internalconfig.OAuthModelAlias{
+ "github-copilot": {{Name: "claude-opus-4.6", Alias: "opus"}},
+ },
+ channel: "github-copilot",
+ input: "opus",
+ want: "claude-opus-4.6",
+ },
{
name: "kimi suffix preserved",
aliases: map[string][]internalconfig.OAuthModelAlias{
@@ -163,6 +190,10 @@ func createAuthForChannel(channel string) *Auth {
return &Auth{Provider: "iflow"}
case "kimi":
return &Auth{Provider: "kimi"}
+ case "kiro":
+ return &Auth{Provider: "kiro"}
+ case "github-copilot":
+ return &Auth{Provider: "github-copilot"}
default:
return &Auth{Provider: channel}
}
@@ -176,6 +207,22 @@ func TestOAuthModelAliasChannel_Kimi(t *testing.T) {
}
}
+func TestOAuthModelAliasChannel_GitHubCopilot(t *testing.T) {
+ t.Parallel()
+
+ if got := OAuthModelAliasChannel("github-copilot", ""); got != "github-copilot" {
+ t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "github-copilot")
+ }
+}
+
+func TestOAuthModelAliasChannel_Kiro(t *testing.T) {
+ t.Parallel()
+
+ if got := OAuthModelAliasChannel("kiro", ""); got != "kiro" {
+ t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, "kiro")
+ }
+}
+
func TestApplyOAuthModelAlias_SuffixPreservation(t *testing.T) {
t.Parallel()
diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go
index 96534bbe..88d0ea52 100644
--- a/sdk/cliproxy/auth/types.go
+++ b/sdk/cliproxy/auth/types.go
@@ -346,6 +346,18 @@ func (a *Auth) AccountInfo() (string, string) {
}
}
+ // For GitHub provider (including github-copilot), return username
+ if strings.HasPrefix(strings.ToLower(a.Provider), "github") {
+ if a.Metadata != nil {
+ if username, ok := a.Metadata["username"].(string); ok {
+ username = strings.TrimSpace(username)
+ if username != "" {
+ return "oauth", username
+ }
+ }
+ }
+ }
+
// Check metadata for email first (OAuth-style auth)
if a.Metadata != nil {
if v, ok := a.Metadata["email"].(string); ok {
diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go
index e89c49c0..2bd12d0a 100644
--- a/sdk/cliproxy/service.go
+++ b/sdk/cliproxy/service.go
@@ -13,6 +13,7 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
+ kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
@@ -100,6 +101,16 @@ func (s *Service) RegisterUsagePlugin(plugin usage.Plugin) {
usage.RegisterPlugin(plugin)
}
+// GetWatcher returns the underlying WatcherWrapper instance.
+// This allows external components (e.g., RefreshManager) to interact with the watcher.
+// Returns nil if the service or watcher is not initialized.
+func (s *Service) GetWatcher() *WatcherWrapper {
+ if s == nil {
+ return nil
+ }
+ return s.watcher
+}
+
// newDefaultAuthManager creates a default authentication manager with all supported providers.
func newDefaultAuthManager() *sdkAuth.Manager {
return sdkAuth.NewManager(
@@ -418,6 +429,12 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
case "kimi":
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
+ case "kiro":
+ s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
+ case "kilo":
+ s.coreManager.RegisterExecutor(executor.NewKiloExecutor(s.cfg))
+ case "github-copilot":
+ s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
default:
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
if providerKey == "" {
@@ -618,6 +635,18 @@ func (s *Service) Run(ctx context.Context) error {
}
watcherWrapper.SetConfig(s.cfg)
+ // 方案 A: 连接 Kiro 后台刷新器回调到 Watcher
+ // 当后台刷新器成功刷新 token 后,立即通知 Watcher 更新内存中的 Auth 对象
+ // 这解决了后台刷新与内存 Auth 对象之间的时间差问题
+ kiroauth.GetRefreshManager().SetOnTokenRefreshed(func(tokenID string, tokenData *kiroauth.KiroTokenData) {
+ if tokenData == nil || watcherWrapper == nil {
+ return
+ }
+ log.Debugf("kiro refresh callback: notifying watcher for token %s", tokenID)
+ watcherWrapper.NotifyTokenRefreshed(tokenID, tokenData.AccessToken, tokenData.RefreshToken, tokenData.ExpiresAt)
+ })
+ log.Debug("kiro: connected background refresh callback to watcher")
+
watcherCtx, watcherCancel := context.WithCancel(context.Background())
s.watcherCancel = watcherCancel
if err = watcherWrapper.Start(watcherCtx); err != nil {
@@ -835,6 +864,15 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
models = applyExcludedModels(models, excluded)
case "kimi":
models = registry.GetKimiModels()
+ models = applyExcludedModels(models, excluded)
+ case "github-copilot":
+ models = registry.GetGitHubCopilotModels()
+ models = applyExcludedModels(models, excluded)
+ case "kiro":
+ models = s.fetchKiroModels(a)
+ models = applyExcludedModels(models, excluded)
+ case "kilo":
+ models = executor.FetchKiloModels(context.Background(), a, s.cfg)
models = applyExcludedModels(models, excluded)
default:
// Handle OpenAI-compatibility providers by name using config
@@ -1397,3 +1435,216 @@ func applyOAuthModelAlias(cfg *config.Config, provider, authKind string, models
}
return out
}
+
+// fetchKiroModels attempts to dynamically fetch Kiro models from the API.
+// If dynamic fetch fails, it falls back to static registry.GetKiroModels().
+func (s *Service) fetchKiroModels(a *coreauth.Auth) []*ModelInfo {
+ if a == nil {
+ log.Debug("kiro: auth is nil, using static models")
+ return registry.GetKiroModels()
+ }
+
+ // Extract token data from auth attributes
+ tokenData := s.extractKiroTokenData(a)
+ if tokenData == nil || tokenData.AccessToken == "" {
+ log.Debug("kiro: no valid token data in auth, using static models")
+ return registry.GetKiroModels()
+ }
+
+ // Create KiroAuth instance
+ kAuth := kiroauth.NewKiroAuth(s.cfg)
+ if kAuth == nil {
+ log.Warn("kiro: failed to create KiroAuth instance, using static models")
+ return registry.GetKiroModels()
+ }
+
+ // Use timeout context for API call
+ ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
+ defer cancel()
+
+ // Attempt to fetch dynamic models
+ apiModels, err := kAuth.ListAvailableModels(ctx, tokenData)
+ if err != nil {
+ log.Warnf("kiro: failed to fetch dynamic models: %v, using static models", err)
+ return registry.GetKiroModels()
+ }
+
+ if len(apiModels) == 0 {
+ log.Debug("kiro: API returned no models, using static models")
+ return registry.GetKiroModels()
+ }
+
+ // Convert API models to ModelInfo
+ models := convertKiroAPIModels(apiModels)
+
+ // Generate agentic variants
+ models = generateKiroAgenticVariants(models)
+
+ log.Infof("kiro: successfully fetched %d models from API (including agentic variants)", len(models))
+ return models
+}
+
+// extractKiroTokenData extracts KiroTokenData from auth attributes and metadata.
+// It supports both config-based tokens (stored in Attributes) and file-based tokens (stored in Metadata).
+func (s *Service) extractKiroTokenData(a *coreauth.Auth) *kiroauth.KiroTokenData {
+ if a == nil {
+ return nil
+ }
+
+ var accessToken, profileArn, refreshToken string
+
+ // Priority 1: Try to get from Attributes (config.yaml source)
+ if a.Attributes != nil {
+ accessToken = strings.TrimSpace(a.Attributes["access_token"])
+ profileArn = strings.TrimSpace(a.Attributes["profile_arn"])
+ refreshToken = strings.TrimSpace(a.Attributes["refresh_token"])
+ }
+
+ // Priority 2: If not found in Attributes, try Metadata (JSON file source)
+ if accessToken == "" && a.Metadata != nil {
+ if at, ok := a.Metadata["access_token"].(string); ok {
+ accessToken = strings.TrimSpace(at)
+ }
+ if pa, ok := a.Metadata["profile_arn"].(string); ok {
+ profileArn = strings.TrimSpace(pa)
+ }
+ if rt, ok := a.Metadata["refresh_token"].(string); ok {
+ refreshToken = strings.TrimSpace(rt)
+ }
+ }
+
+ // access_token is required
+ if accessToken == "" {
+ return nil
+ }
+
+ return &kiroauth.KiroTokenData{
+ AccessToken: accessToken,
+ ProfileArn: profileArn,
+ RefreshToken: refreshToken,
+ }
+}
+
+// convertKiroAPIModels converts Kiro API models to ModelInfo slice.
+func convertKiroAPIModels(apiModels []*kiroauth.KiroModel) []*ModelInfo {
+ if len(apiModels) == 0 {
+ return nil
+ }
+
+ now := time.Now().Unix()
+ models := make([]*ModelInfo, 0, len(apiModels))
+
+ for _, m := range apiModels {
+ if m == nil || m.ModelID == "" {
+ continue
+ }
+
+ // Create model ID with kiro- prefix
+ modelID := "kiro-" + normalizeKiroModelID(m.ModelID)
+
+ info := &ModelInfo{
+ ID: modelID,
+ Object: "model",
+ Created: now,
+ OwnedBy: "aws",
+ Type: "kiro",
+ DisplayName: formatKiroDisplayName(m.ModelName, m.RateMultiplier),
+ Description: m.Description,
+ ContextLength: 200000,
+ MaxCompletionTokens: 64000,
+ Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
+ }
+
+ if m.MaxInputTokens > 0 {
+ info.ContextLength = m.MaxInputTokens
+ }
+
+ models = append(models, info)
+ }
+
+ return models
+}
+
+// normalizeKiroModelID normalizes a Kiro model ID by converting dots to dashes
+// and removing common prefixes.
+func normalizeKiroModelID(modelID string) string {
+ // Remove common prefixes
+ modelID = strings.TrimPrefix(modelID, "anthropic.")
+ modelID = strings.TrimPrefix(modelID, "amazon.")
+
+ // Replace dots with dashes for consistency
+ modelID = strings.ReplaceAll(modelID, ".", "-")
+
+ // Replace underscores with dashes
+ modelID = strings.ReplaceAll(modelID, "_", "-")
+
+ return strings.ToLower(modelID)
+}
+
+// formatKiroDisplayName formats the display name with rate multiplier info.
+func formatKiroDisplayName(modelName string, rateMultiplier float64) string {
+ if modelName == "" {
+ return ""
+ }
+
+ displayName := "Kiro " + modelName
+ if rateMultiplier > 0 && rateMultiplier != 1.0 {
+ displayName += fmt.Sprintf(" (%.1fx credit)", rateMultiplier)
+ }
+
+ return displayName
+}
+
+// generateKiroAgenticVariants generates agentic variants for Kiro models.
+// Agentic variants have optimized system prompts for coding agents.
+func generateKiroAgenticVariants(models []*ModelInfo) []*ModelInfo {
+ if len(models) == 0 {
+ return models
+ }
+
+ result := make([]*ModelInfo, 0, len(models)*2)
+ result = append(result, models...)
+
+ for _, m := range models {
+ if m == nil {
+ continue
+ }
+
+ // Skip if already an agentic variant
+ if strings.HasSuffix(m.ID, "-agentic") {
+ continue
+ }
+
+ // Skip auto models from agentic variant generation
+ if strings.Contains(m.ID, "-auto") {
+ continue
+ }
+
+ // Create agentic variant
+ agentic := &ModelInfo{
+ ID: m.ID + "-agentic",
+ Object: m.Object,
+ Created: m.Created,
+ OwnedBy: m.OwnedBy,
+ Type: m.Type,
+ DisplayName: m.DisplayName + " (Agentic)",
+ Description: m.Description + " - Optimized for coding agents (chunked writes)",
+ ContextLength: m.ContextLength,
+ MaxCompletionTokens: m.MaxCompletionTokens,
+ }
+
+ // Copy thinking support if present
+ if m.Thinking != nil {
+ agentic.Thinking = ®istry.ThinkingSupport{
+ Min: m.Thinking.Min,
+ Max: m.Thinking.Max,
+ ZeroAllowed: m.Thinking.ZeroAllowed,
+ DynamicAllowed: m.Thinking.DynamicAllowed,
+ }
+ }
+
+ result = append(result, agentic)
+ }
+
+ return result
+}
diff --git a/sdk/cliproxy/types.go b/sdk/cliproxy/types.go
index 1521dffe..ee8f761d 100644
--- a/sdk/cliproxy/types.go
+++ b/sdk/cliproxy/types.go
@@ -89,6 +89,7 @@ type WatcherWrapper struct {
snapshotAuths func() []*coreauth.Auth
setUpdateQueue func(queue chan<- watcher.AuthUpdate)
dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool
+ notifyTokenRefreshed func(tokenID, accessToken, refreshToken, expiresAt string) // 方案 A: 后台刷新通知
}
// Start proxies to the underlying watcher Start implementation.
@@ -146,3 +147,16 @@ func (w *WatcherWrapper) SetAuthUpdateQueue(queue chan<- watcher.AuthUpdate) {
}
w.setUpdateQueue(queue)
}
+
+// NotifyTokenRefreshed 通知 Watcher 后台刷新器已更新 token
+// 这是方案 A 的核心方法,用于解决后台刷新与内存 Auth 对象的时间差问题
+// tokenID: token 文件名(如 kiro-xxx.json)
+// accessToken: 新的 access token
+// refreshToken: 新的 refresh token
+// expiresAt: 新的过期时间(RFC3339 格式)
+func (w *WatcherWrapper) NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt string) {
+ if w == nil || w.notifyTokenRefreshed == nil {
+ return
+ }
+ w.notifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt)
+}
diff --git a/sdk/cliproxy/watcher.go b/sdk/cliproxy/watcher.go
index caeadf19..e6e91bdd 100644
--- a/sdk/cliproxy/watcher.go
+++ b/sdk/cliproxy/watcher.go
@@ -31,5 +31,8 @@ func defaultWatcherFactory(configPath, authDir string, reload func(*config.Confi
dispatchRuntimeUpdate: func(update watcher.AuthUpdate) bool {
return w.DispatchRuntimeAuthUpdate(update)
},
+ notifyTokenRefreshed: func(tokenID, accessToken, refreshToken, expiresAt string) {
+ w.NotifyTokenRefreshed(tokenID, accessToken, refreshToken, expiresAt)
+ },
}, nil
}
diff --git a/test/thinking_conversion_test.go b/test/thinking_conversion_test.go
index 781a1667..e7beb1a3 100644
--- a/test/thinking_conversion_test.go
+++ b/test/thinking_conversion_test.go
@@ -1316,6 +1316,122 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) {
includeThoughts: "true",
expectErr: false,
},
+
+ // GitHub Copilot tests: gpt-5, gpt-5.1, gpt-5.2 (Levels=low/medium/high, some with none/xhigh)
+ // Testing /chat/completions endpoint (openai format) - with suffix
+
+ // Case 112: OpenAI to gpt-5, level high → high
+ {
+ name: "112",
+ from: "openai",
+ to: "github-copilot",
+ model: "gpt-5(high)",
+ inputJSON: `{"model":"gpt-5(high)","messages":[{"role":"user","content":"hi"}]}`,
+ expectField: "reasoning_effort",
+ expectValue: "high",
+ expectErr: false,
+ },
+ // Case 113: OpenAI to gpt-5, level none → clamped to low (ZeroAllowed=false)
+ {
+ name: "113",
+ from: "openai",
+ to: "github-copilot",
+ model: "gpt-5(none)",
+ inputJSON: `{"model":"gpt-5(none)","messages":[{"role":"user","content":"hi"}]}`,
+ expectField: "reasoning_effort",
+ expectValue: "low",
+ expectErr: false,
+ },
+ // Case 114: OpenAI to gpt-5.1, level none → none (ZeroAllowed=true)
+ {
+ name: "114",
+ from: "openai",
+ to: "github-copilot",
+ model: "gpt-5.1(none)",
+ inputJSON: `{"model":"gpt-5.1(none)","messages":[{"role":"user","content":"hi"}]}`,
+ expectField: "reasoning_effort",
+ expectValue: "none",
+ expectErr: false,
+ },
+ // Case 115: OpenAI to gpt-5.2, level xhigh → xhigh
+ {
+ name: "115",
+ from: "openai",
+ to: "github-copilot",
+ model: "gpt-5.2(xhigh)",
+ inputJSON: `{"model":"gpt-5.2(xhigh)","messages":[{"role":"user","content":"hi"}]}`,
+ expectField: "reasoning_effort",
+ expectValue: "xhigh",
+ expectErr: false,
+ },
+ // Case 116: OpenAI to gpt-5, level xhigh (out of range) → error
+ {
+ name: "116",
+ from: "openai",
+ to: "github-copilot",
+ model: "gpt-5(xhigh)",
+ inputJSON: `{"model":"gpt-5(xhigh)","messages":[{"role":"user","content":"hi"}]}`,
+ expectField: "",
+ expectErr: true,
+ },
+ // Case 117: Claude to gpt-5.1, budget 0 → none (ZeroAllowed=true)
+ {
+ name: "117",
+ from: "claude",
+ to: "github-copilot",
+ model: "gpt-5.1(0)",
+ inputJSON: `{"model":"gpt-5.1(0)","messages":[{"role":"user","content":"hi"}]}`,
+ expectField: "reasoning_effort",
+ expectValue: "none",
+ expectErr: false,
+ },
+
+ // GitHub Copilot tests: /responses endpoint (codex format) - with suffix
+
+ // Case 118: OpenAI-Response to gpt-5-codex, level high → high
+ {
+ name: "118",
+ from: "openai-response",
+ to: "github-copilot",
+ model: "gpt-5-codex(high)",
+ inputJSON: `{"model":"gpt-5-codex(high)","input":[{"role":"user","content":"hi"}]}`,
+ expectField: "reasoning.effort",
+ expectValue: "high",
+ expectErr: false,
+ },
+ // Case 119: OpenAI-Response to gpt-5.2-codex, level xhigh → xhigh
+ {
+ name: "119",
+ from: "openai-response",
+ to: "github-copilot",
+ model: "gpt-5.2-codex(xhigh)",
+ inputJSON: `{"model":"gpt-5.2-codex(xhigh)","input":[{"role":"user","content":"hi"}]}`,
+ expectField: "reasoning.effort",
+ expectValue: "xhigh",
+ expectErr: false,
+ },
+ // Case 120: OpenAI-Response to gpt-5.2-codex, level none → none
+ {
+ name: "120",
+ from: "openai-response",
+ to: "github-copilot",
+ model: "gpt-5.2-codex(none)",
+ inputJSON: `{"model":"gpt-5.2-codex(none)","input":[{"role":"user","content":"hi"}]}`,
+ expectField: "reasoning.effort",
+ expectValue: "none",
+ expectErr: false,
+ },
+ // Case 121: OpenAI-Response to gpt-5-codex, level none → clamped to low (ZeroAllowed=false)
+ {
+ name: "121",
+ from: "openai-response",
+ to: "github-copilot",
+ model: "gpt-5-codex(none)",
+ inputJSON: `{"model":"gpt-5-codex(none)","input":[{"role":"user","content":"hi"}]}`,
+ expectField: "reasoning.effort",
+ expectValue: "low",
+ expectErr: false,
+ },
}
runThinkingTests(t, cases)
@@ -2585,6 +2701,122 @@ func TestThinkingE2EMatrix_Body(t *testing.T) {
includeThoughts: "true",
expectErr: false,
},
+
+ // GitHub Copilot tests: gpt-5, gpt-5.1, gpt-5.2 (Levels=low/medium/high, some with none/xhigh)
+ // Testing /chat/completions endpoint (openai format) - with body params
+
+ // Case 112: OpenAI to gpt-5, reasoning_effort=high → high
+ {
+ name: "112",
+ from: "openai",
+ to: "github-copilot",
+ model: "gpt-5",
+ inputJSON: `{"model":"gpt-5","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"high"}`,
+ expectField: "reasoning_effort",
+ expectValue: "high",
+ expectErr: false,
+ },
+ // Case 113: OpenAI to gpt-5, reasoning_effort=none → clamped to low (ZeroAllowed=false)
+ {
+ name: "113",
+ from: "openai",
+ to: "github-copilot",
+ model: "gpt-5",
+ inputJSON: `{"model":"gpt-5","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`,
+ expectField: "reasoning_effort",
+ expectValue: "low",
+ expectErr: false,
+ },
+ // Case 114: OpenAI to gpt-5.1, reasoning_effort=none → none (ZeroAllowed=true)
+ {
+ name: "114",
+ from: "openai",
+ to: "github-copilot",
+ model: "gpt-5.1",
+ inputJSON: `{"model":"gpt-5.1","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`,
+ expectField: "reasoning_effort",
+ expectValue: "none",
+ expectErr: false,
+ },
+ // Case 115: OpenAI to gpt-5.2, reasoning_effort=xhigh → xhigh
+ {
+ name: "115",
+ from: "openai",
+ to: "github-copilot",
+ model: "gpt-5.2",
+ inputJSON: `{"model":"gpt-5.2","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`,
+ expectField: "reasoning_effort",
+ expectValue: "xhigh",
+ expectErr: false,
+ },
+ // Case 116: OpenAI to gpt-5, reasoning_effort=xhigh (out of range) → error
+ {
+ name: "116",
+ from: "openai",
+ to: "github-copilot",
+ model: "gpt-5",
+ inputJSON: `{"model":"gpt-5","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`,
+ expectField: "",
+ expectErr: true,
+ },
+ // Case 117: Claude to gpt-5.1, thinking.budget_tokens=0 → none (ZeroAllowed=true)
+ {
+ name: "117",
+ from: "claude",
+ to: "github-copilot",
+ model: "gpt-5.1",
+ inputJSON: `{"model":"gpt-5.1","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":0}}`,
+ expectField: "reasoning_effort",
+ expectValue: "none",
+ expectErr: false,
+ },
+
+ // GitHub Copilot tests: /responses endpoint (codex format) - with body params
+
+ // Case 118: OpenAI-Response to gpt-5-codex, reasoning.effort=high → high
+ {
+ name: "118",
+ from: "openai-response",
+ to: "github-copilot",
+ model: "gpt-5-codex",
+ inputJSON: `{"model":"gpt-5-codex","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"high"}}`,
+ expectField: "reasoning.effort",
+ expectValue: "high",
+ expectErr: false,
+ },
+ // Case 119: OpenAI-Response to gpt-5.2-codex, reasoning.effort=xhigh → xhigh
+ {
+ name: "119",
+ from: "openai-response",
+ to: "github-copilot",
+ model: "gpt-5.2-codex",
+ inputJSON: `{"model":"gpt-5.2-codex","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"xhigh"}}`,
+ expectField: "reasoning.effort",
+ expectValue: "xhigh",
+ expectErr: false,
+ },
+ // Case 120: OpenAI-Response to gpt-5.2-codex, reasoning.effort=none → none
+ {
+ name: "120",
+ from: "openai-response",
+ to: "github-copilot",
+ model: "gpt-5.2-codex",
+ inputJSON: `{"model":"gpt-5.2-codex","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"none"}}`,
+ expectField: "reasoning.effort",
+ expectValue: "none",
+ expectErr: false,
+ },
+ // Case 121: OpenAI-Response to gpt-5-codex, reasoning.effort=none → clamped to low (ZeroAllowed=false)
+ {
+ name: "121",
+ from: "openai-response",
+ to: "github-copilot",
+ model: "gpt-5-codex",
+ inputJSON: `{"model":"gpt-5-codex","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"none"}}`,
+ expectField: "reasoning.effort",
+ expectValue: "low",
+ expectErr: false,
+ },
}
runThinkingTests(t, cases)
@@ -2813,6 +3045,51 @@ func getTestModels() []*registry.ModelInfo {
DisplayName: "MiniMax Test Model",
Thinking: ®istry.ThinkingSupport{Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"}},
},
+ {
+ ID: "gpt-5",
+ Object: "model",
+ Created: 1700000000,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5",
+ Thinking: ®istry.ThinkingSupport{Levels: []string{"low", "medium", "high"}, ZeroAllowed: false, DynamicAllowed: false},
+ },
+ {
+ ID: "gpt-5.1",
+ Object: "model",
+ Created: 1700000000,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5.1",
+ Thinking: ®istry.ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}, ZeroAllowed: true, DynamicAllowed: false},
+ },
+ {
+ ID: "gpt-5.2",
+ Object: "model",
+ Created: 1700000000,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5.2",
+ Thinking: ®istry.ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}, ZeroAllowed: true, DynamicAllowed: false},
+ },
+ {
+ ID: "gpt-5-codex",
+ Object: "model",
+ Created: 1700000000,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5 Codex",
+ Thinking: ®istry.ThinkingSupport{Levels: []string{"low", "medium", "high"}, ZeroAllowed: false, DynamicAllowed: false},
+ },
+ {
+ ID: "gpt-5.2-codex",
+ Object: "model",
+ Created: 1700000000,
+ OwnedBy: "github-copilot",
+ Type: "github-copilot",
+ DisplayName: "GPT-5.2 Codex",
+ Thinking: ®istry.ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}, ZeroAllowed: true, DynamicAllowed: false},
+ },
}
}
@@ -2831,6 +3108,15 @@ func runThinkingTests(t *testing.T, cases []thinkingTestCase) {
translateTo = "openai"
applyTo = "iflow"
}
+ if tc.to == "github-copilot" {
+ if tc.from == "openai-response" {
+ translateTo = "codex"
+ applyTo = "codex"
+ } else {
+ translateTo = "openai"
+ applyTo = "openai"
+ }
+ }
body := sdktranslator.TranslateRequest(
sdktranslator.FromString(tc.from),