diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 3aacf4f5..de924672 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -7,7 +7,7 @@ on: env: APP_NAME: CLIProxyAPI - DOCKERHUB_REPO: eceasy/cli-proxy-api + DOCKERHUB_REPO: eceasy/cli-proxy-api-plus jobs: docker: diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 4bb5e63b..4c4aafe7 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -23,7 +23,8 @@ jobs: cache: true - name: Generate Build Metadata run: | - echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV + VERSION=$(git describe --tags --always --dirty) + echo "VERSION=${VERSION}" >> $GITHUB_ENV echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV - uses: goreleaser/goreleaser-action@v4 diff --git a/.gitignore b/.gitignore index 14395e6d..78c9de68 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Binaries cli-proxy-api +cliproxy *.exe # Configuration @@ -32,6 +33,7 @@ GEMINI.md .claude/* .serena/* .bmad/* +.mcp/cache/ # macOS .DS_Store diff --git a/.goreleaser.yml b/.goreleaser.yml index 31d05e6d..6e1829ed 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -1,5 +1,5 @@ builds: - - id: "cli-proxy-api" + - id: "cli-proxy-api-plus" env: - CGO_ENABLED=0 goos: @@ -10,11 +10,11 @@ builds: - amd64 - arm64 main: ./cmd/server/ - binary: cli-proxy-api + binary: cli-proxy-api-plus ldflags: - - -s -w -X 'main.Version={{.Version}}' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}' + - -s -w -X 'main.Version={{.Version}}-plus' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}' archives: - - id: "cli-proxy-api" + - id: "cli-proxy-api-plus" format: tar.gz format_overrides: - goos: windows diff --git a/Dockerfile b/Dockerfile index 8623dc5e..98509423 100644 --- a/Dockerfile +++ b/Dockerfile @@ -12,7 +12,7 @@ ARG VERSION=dev ARG COMMIT=none ARG BUILD_DATE=unknown -RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPI ./cmd/server/ +RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}-plus' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPIPlus ./cmd/server/ FROM alpine:3.22.0 @@ -20,7 +20,7 @@ RUN apk add --no-cache tzdata RUN mkdir /CLIProxyAPI -COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI +COPY --from=builder ./app/CLIProxyAPIPlus /CLIProxyAPI/CLIProxyAPIPlus COPY config.example.yaml /CLIProxyAPI/config.example.yaml @@ -32,4 +32,4 @@ ENV TZ=Asia/Shanghai RUN cp /usr/share/zoneinfo/${TZ} /etc/localtime && echo "${TZ}" > /etc/timezone -CMD ["./CLIProxyAPI"] \ No newline at end of file +CMD ["./CLIProxyAPIPlus"] \ No newline at end of file diff --git a/README.md b/README.md index 03121c7b..d00e91c9 100644 --- a/README.md +++ b/README.md @@ -1,106 +1,23 @@ -# CLI Proxy API +# CLIProxyAPI Plus -English | [中文](README_CN.md) +English | [Chinese](README_CN.md) -A proxy server that provides OpenAI/Gemini/Claude/Codex compatible API interfaces for CLI. +This is the Plus version of [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI), adding support for third-party providers on top of the mainline project. -It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth. +All third-party provider support is maintained by community contributors; CLIProxyAPI does not provide technical support. Please contact the corresponding community maintainer if you need assistance. -So you can use local or multi-account CLI access with OpenAI(include Responses)/Gemini/Claude-compatible clients and SDKs. +The Plus release stays in lockstep with the mainline features. -## Sponsor +## Differences from the Mainline -[![z.ai](https://assets.router-for.me/english.png)](https://z.ai/subscribe?ic=8JVLJQFSKB) - -This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN. - -GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.6 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences. - -Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB - -## Overview - -- OpenAI/Gemini/Claude compatible API endpoints for CLI models -- OpenAI Codex support (GPT models) via OAuth login -- Claude Code support via OAuth login -- Qwen Code support via OAuth login -- iFlow support via OAuth login -- Amp CLI and IDE extensions support with provider routing -- Streaming and non-streaming responses -- Function calling/tools support -- Multimodal input support (text and images) -- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude, Qwen and iFlow) -- Simple CLI authentication flows (Gemini, OpenAI, Claude, Qwen and iFlow) -- Generative Language API Key support -- AI Studio Build multi-account load balancing -- Gemini CLI multi-account load balancing -- Claude Code multi-account load balancing -- Qwen Code multi-account load balancing -- iFlow multi-account load balancing -- OpenAI Codex multi-account load balancing -- OpenAI-compatible upstream providers via config (e.g., OpenRouter) -- Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`) - -## Getting Started - -CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/) - -## Management API - -see [MANAGEMENT_API.md](https://help.router-for.me/management/api) - -## Amp CLI Support - -CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and Amp IDE extensions, enabling you to use your Google/ChatGPT/Claude OAuth subscriptions with Amp's coding tools: - -- Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`) -- Management proxy for OAuth authentication and account features -- Smart model fallback with automatic routing -- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`) -- Security-first design with localhost-only management endpoints - -**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/agent-client/amp-cli.html)** - -## SDK Docs - -- Usage: [docs/sdk-usage.md](docs/sdk-usage.md) -- Advanced (executors & translators): [docs/sdk-advanced.md](docs/sdk-advanced.md) -- Access: [docs/sdk-access.md](docs/sdk-access.md) -- Watcher: [docs/sdk-watcher.md](docs/sdk-watcher.md) -- Custom Provider Example: `examples/custom-provider` +- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth) +- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/) ## Contributing -Contributions are welcome! Please feel free to submit a Pull Request. +This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected. -1. Fork the repository -2. Create your feature branch (`git checkout -b feature/amazing-feature`) -3. Commit your changes (`git commit -m 'Add some amazing feature'`) -4. Push to the branch (`git push origin feature/amazing-feature`) -5. Open a Pull Request - -## Who is with us? - -Those projects are based on CLIProxyAPI: - -### [vibeproxy](https://github.com/automazeio/vibeproxy) - -Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with AI coding tools - no API keys needed - -### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator) - -Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed - -### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs) - -CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed - -### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal) - -Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed. - -> [!NOTE] -> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list. +If you need to submit any non-third-party provider changes, please open them against the mainline repository. ## License diff --git a/README_CN.md b/README_CN.md index 4624e88f..21132b86 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,113 +1,24 @@ -# CLI 代理 API +# CLIProxyAPI Plus [English](README.md) | 中文 -一个为 CLI 提供 OpenAI/Gemini/Claude/Codex 兼容 API 接口的代理服务器。 +这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。 -现已支持通过 OAuth 登录接入 OpenAI Codex(GPT 系列)和 Claude Code。 +所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。 -您可以使用本地或多账户的CLI方式,通过任何与 OpenAI(包括Responses)/Gemini/Claude 兼容的客户端和SDK进行访问。 +该 Plus 版本的主线功能与主线功能强制同步。 -## 赞助商 +## 与主线版本版本差异 -[![bigmodel.cn](https://assets.router-for.me/chinese.png)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII) - -本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。 - -GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.6,为开发者提供顶尖的编码体验。 - -智谱AI为本软件提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII - -## 功能特性 - -- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点 -- 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录) -- 新增 Claude Code 支持(OAuth 登录) -- 新增 Qwen Code 支持(OAuth 登录) -- 新增 iFlow 支持(OAuth 登录) -- 支持流式与非流式响应 -- 函数调用/工具支持 -- 多模态输入(文本、图片) -- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude、Qwen 与 iFlow) -- 简单的 CLI 身份验证流程(Gemini、OpenAI、Claude、Qwen 与 iFlow) -- 支持 Gemini AIStudio API 密钥 -- 支持 AI Studio Build 多账户轮询 -- 支持 Gemini CLI 多账户轮询 -- 支持 Claude Code 多账户轮询 -- 支持 Qwen Code 多账户轮询 -- 支持 iFlow 多账户轮询 -- 支持 OpenAI Codex 多账户轮询 -- 通过配置接入上游 OpenAI 兼容提供商(例如 OpenRouter) -- 可复用的 Go SDK(见 `docs/sdk-usage_CN.md`) - -## 新手入门 - -CLIProxyAPI 用户手册: [https://help.router-for.me/](https://help.router-for.me/cn/) - -## 管理 API 文档 - -请参见 [MANAGEMENT_API_CN.md](https://help.router-for.me/cn/management/api) - -## Amp CLI 支持 - -CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支持,可让你使用自己的 Google/ChatGPT/Claude OAuth 订阅来配合 Amp 编码工具: - -- 提供商路由别名,兼容 Amp 的 API 路径模式(`/api/provider/{provider}/v1...`) -- 管理代理,处理 OAuth 认证和账号功能 -- 智能模型回退与自动路由 -- 以安全为先的设计,管理端点仅限 localhost - -**→ [Amp CLI 完整集成指南](https://help.router-for.me/cn/agent-client/amp-cli.html)** - -## SDK 文档 - -- 使用文档:[docs/sdk-usage_CN.md](docs/sdk-usage_CN.md) -- 高级(执行器与翻译器):[docs/sdk-advanced_CN.md](docs/sdk-advanced_CN.md) -- 认证: [docs/sdk-access_CN.md](docs/sdk-access_CN.md) -- 凭据加载/更新: [docs/sdk-watcher_CN.md](docs/sdk-watcher_CN.md) -- 自定义 Provider 示例:`examples/custom-provider` +- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供 +- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供 ## 贡献 -欢迎贡献!请随时提交 Pull Request。 +该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。 -1. Fork 仓库 -2. 创建您的功能分支(`git checkout -b feature/amazing-feature`) -3. 提交您的更改(`git commit -m 'Add some amazing feature'`) -4. 推送到分支(`git push origin feature/amazing-feature`) -5. 打开 Pull Request - -## 谁与我们在一起? - -这些项目基于 CLIProxyAPI: - -### [vibeproxy](https://github.com/automazeio/vibeproxy) - -一个原生 macOS 菜单栏应用,让您可以使用 Claude Code & ChatGPT 订阅服务和 AI 编程工具,无需 API 密钥。 - -### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator) - -一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。 - -### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs) - -CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户和替代模型(Gemini, Codex, Antigravity),无需 API 密钥。 - -### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal) - -基于 macOS 平台的原生 CLIProxyAPI GUI:配置供应商、模型映射以及OAuth端点,无需 API 密钥。 - -> [!NOTE] -> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。 +如果需要提交任何非第三方供应商支持的 Pull Request,请提交到主线版本。 ## 许可证 -此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。 - -## 写给所有中国网友的 - -QQ 群:188637136 - -或 - -Telegram 群:https://t.me/CLIProxyAPI +此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。 \ No newline at end of file diff --git a/cmd/server/main.go b/cmd/server/main.go index aec51ab8..fe648f6c 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -47,6 +47,19 @@ func init() { buildinfo.BuildDate = BuildDate } +// setKiroIncognitoMode sets the incognito browser mode for Kiro authentication. +// Kiro defaults to incognito mode for multi-account support. +// Users can explicitly override with --incognito or --no-incognito flags. +func setKiroIncognitoMode(cfg *config.Config, useIncognito, noIncognito bool) { + if useIncognito { + cfg.IncognitoBrowser = true + } else if noIncognito { + cfg.IncognitoBrowser = false + } else { + cfg.IncognitoBrowser = true // Kiro default + } +} + // main is the entry point of the application. // It parses command-line flags, loads configuration, and starts the appropriate // service based on the provided flags (login, codex-login, or server mode). @@ -62,10 +75,17 @@ func main() { var iflowCookie bool var noBrowser bool var antigravityLogin bool + var kiroLogin bool + var kiroGoogleLogin bool + var kiroAWSLogin bool + var kiroImport bool + var githubCopilotLogin bool var projectID string var vertexImport string var configPath string var password string + var noIncognito bool + var useIncognito bool // Define command-line flags for different operation modes. flag.BoolVar(&login, "login", false, "Login Google Account") @@ -75,7 +95,14 @@ func main() { flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth") flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie") flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth") + flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)") + flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)") flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth") + flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth") + flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)") + flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)") + flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)") + flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") @@ -453,6 +480,9 @@ func main() { } else if antigravityLogin { // Handle Antigravity login cmd.DoAntigravityLogin(cfg, options) + } else if githubCopilotLogin { + // Handle GitHub Copilot login + cmd.DoGitHubCopilotLogin(cfg, options) } else if codexLogin { // Handle Codex login cmd.DoCodexLogin(cfg, options) @@ -465,6 +495,26 @@ func main() { cmd.DoIFlowLogin(cfg, options) } else if iflowCookie { cmd.DoIFlowCookieAuth(cfg, options) + } else if kiroLogin { + // For Kiro auth, default to incognito mode for multi-account support + // Users can explicitly override with --no-incognito + // Note: This config mutation is safe - auth commands exit after completion + // and don't share config with StartService (which is in the else branch) + setKiroIncognitoMode(cfg, useIncognito, noIncognito) + cmd.DoKiroLogin(cfg, options) + } else if kiroGoogleLogin { + // For Kiro auth, default to incognito mode for multi-account support + // Users can explicitly override with --no-incognito + // Note: This config mutation is safe - auth commands exit after completion + setKiroIncognitoMode(cfg, useIncognito, noIncognito) + cmd.DoKiroGoogleLogin(cfg, options) + } else if kiroAWSLogin { + // For Kiro auth, default to incognito mode for multi-account support + // Users can explicitly override with --no-incognito + setKiroIncognitoMode(cfg, useIncognito, noIncognito) + cmd.DoKiroAWSLogin(cfg, options) + } else if kiroImport { + cmd.DoKiroImport(cfg, options) } else { // In cloud deploy mode without config file, just wait for shutdown signals if isCloudDeploy && !configFileExists { diff --git a/config.example.yaml b/config.example.yaml index e93d71b6..922fcd96 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -39,6 +39,11 @@ api-keys: # Enable debug logging debug: false +# Open OAuth URLs in incognito/private browser mode. +# Useful when you want to login with a different account without logging out from your current session. +# Default: false (but Kiro auth defaults to true for multi-account support) +incognito-browser: true + # When true, write application logs to rotating files instead of stdout logging-to-file: false @@ -106,6 +111,16 @@ ws-auth: false # - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking) # - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022) +# Kiro (AWS CodeWhisperer) configuration +# Note: Kiro API currently only operates in us-east-1 region +#kiro: +# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file +# agent-task-type: "" # optional: "vibe" or empty (API default) +# - access-token: "aoaAAAAA..." # or provide tokens directly +# refresh-token: "aorAAAAA..." +# profile-arn: "arn:aws:codewhisperer:us-east-1:..." +# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override + # OpenAI compatibility providers # openai-compatibility: # - name: "openrouter" # The name of the provider; it will be used in the user agent and other places. diff --git a/docker-compose.yml b/docker-compose.yml index 29712419..80693fbe 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,6 @@ services: cli-proxy-api: - image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api:latest} + image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api-plus:latest} pull_policy: always build: context: . @@ -9,7 +9,7 @@ services: VERSION: ${VERSION:-dev} COMMIT: ${COMMIT:-none} BUILD_DATE: ${BUILD_DATE:-unknown} - container_name: cli-proxy-api + container_name: cli-proxy-api-plus # env_file: # - .env environment: diff --git a/go.mod b/go.mod index 963d9c49..632ac35a 100644 --- a/go.mod +++ b/go.mod @@ -13,14 +13,15 @@ require ( github.com/joho/godotenv v1.5.1 github.com/klauspost/compress v1.17.4 github.com/minio/minio-go/v7 v7.0.66 + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/sirupsen/logrus v1.9.3 - github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/tiktoken-go/tokenizer v0.7.0 golang.org/x/crypto v0.45.0 golang.org/x/net v0.47.0 golang.org/x/oauth2 v0.30.0 + golang.org/x/term v0.36.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 4705336b..d4a4cb9d 100644 --- a/go.sum +++ b/go.sum @@ -116,6 +116,8 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6 github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0= github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= @@ -126,8 +128,6 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= -github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -169,6 +169,7 @@ golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKl golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index e29dc29f..0455a62e 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -3,6 +3,9 @@ package management import ( "bytes" "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -23,6 +26,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini" iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" @@ -37,9 +41,32 @@ import ( ) var ( - oauthStatus = make(map[string]string) + oauthStatus = make(map[string]string) + oauthStatusMutex sync.RWMutex ) +// getOAuthStatus safely retrieves an OAuth status +func getOAuthStatus(key string) (string, bool) { + oauthStatusMutex.RLock() + defer oauthStatusMutex.RUnlock() + status, ok := oauthStatus[key] + return status, ok +} + +// setOAuthStatus safely sets an OAuth status +func setOAuthStatus(key string, status string) { + oauthStatusMutex.Lock() + defer oauthStatusMutex.Unlock() + oauthStatus[key] = status +} + +// deleteOAuthStatus safely deletes an OAuth status +func deleteOAuthStatus(key string) { + oauthStatusMutex.Lock() + defer oauthStatusMutex.Unlock() + delete(oauthStatus, key) +} + var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} const ( @@ -812,7 +839,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { deadline := time.Now().Add(timeout) for { if time.Now().After(deadline) { - oauthStatus[state] = "Timeout waiting for OAuth callback" + setOAuthStatus(state, "Timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback") } data, errRead := os.ReadFile(path) @@ -837,13 +864,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errStr := resultMap["error"]; errStr != "" { oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(claude.GetUserFriendlyMessage(oauthErr)) - oauthStatus[state] = "Bad request" + setOAuthStatus(state, "Bad request") return } if resultMap["state"] != state { authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"])) log.Error(claude.GetUserFriendlyMessage(authErr)) - oauthStatus[state] = "State code error" + setOAuthStatus(state, "State code error") return } @@ -876,7 +903,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errDo != nil { authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) - oauthStatus[state] = "Failed to exchange authorization code for tokens" + setOAuthStatus(state, "Failed to exchange authorization code for tokens") return } defer func() { @@ -887,7 +914,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) - oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)) return } var tResp struct { @@ -900,7 +927,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } if errU := json.Unmarshal(respBody, &tResp); errU != nil { log.Errorf("failed to parse token response: %v", errU) - oauthStatus[state] = "Failed to parse token response" + setOAuthStatus(state, "Failed to parse token response") return } bundle := &claude.ClaudeAuthBundle{ @@ -925,7 +952,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - oauthStatus[state] = "Failed to save authentication tokens" + setOAuthStatus(state, "Failed to save authentication tokens") return } @@ -934,10 +961,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Claude services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -996,7 +1023,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - oauthStatus[state] = "OAuth flow timed out" + setOAuthStatus(state, "OAuth flow timed out") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -1005,13 +1032,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := m["error"]; errStr != "" { log.Errorf("Authentication failed: %s", errStr) - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") return } authCode = m["code"] if authCode == "" { log.Errorf("Authentication failed: code not found") - oauthStatus[state] = "Authentication failed: code not found" + setOAuthStatus(state, "Authentication failed: code not found") return } break @@ -1023,7 +1050,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { token, err := conf.Exchange(ctx, authCode) if err != nil { log.Errorf("Failed to exchange token: %v", err) - oauthStatus[state] = "Failed to exchange token" + setOAuthStatus(state, "Failed to exchange token") return } @@ -1034,7 +1061,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errNewRequest != nil { log.Errorf("Could not get user info: %v", errNewRequest) - oauthStatus[state] = "Could not get user info" + setOAuthStatus(state, "Could not get user info") return } req.Header.Set("Content-Type", "application/json") @@ -1043,7 +1070,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { resp, errDo := authHTTPClient.Do(req) if errDo != nil { log.Errorf("Failed to execute request: %v", errDo) - oauthStatus[state] = "Failed to execute request" + setOAuthStatus(state, "Failed to execute request") return } defer func() { @@ -1055,7 +1082,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { bodyBytes, _ := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 300 { log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)) return } @@ -1064,7 +1091,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { fmt.Printf("Authenticated user email: %s\n", email) } else { fmt.Println("Failed to get user email from token") - oauthStatus[state] = "Failed to get user email from token" + setOAuthStatus(state, "Failed to get user email from token") } // Marshal/unmarshal oauth2.Token to generic map and enrich fields @@ -1072,7 +1099,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { jsonData, _ := json.Marshal(token) if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil { log.Errorf("Failed to unmarshal token: %v", errUnmarshal) - oauthStatus[state] = "Failed to unmarshal token" + setOAuthStatus(state, "Failed to unmarshal token") return } @@ -1098,7 +1125,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) if errGetClient != nil { log.Errorf("failed to get authenticated client: %v", errGetClient) - oauthStatus[state] = "Failed to get authenticated client" + setOAuthStatus(state, "Failed to get authenticated client") return } fmt.Println("Authentication successful.") @@ -1108,12 +1135,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts) if errAll != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll) - oauthStatus[state] = "Failed to complete Gemini CLI onboarding" + setOAuthStatus(state, "Failed to complete Gemini CLI onboarding") return } if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errVerify) - oauthStatus[state] = "Failed to verify Cloud AI API status" + setOAuthStatus(state, "Failed to verify Cloud AI API status") return } ts.ProjectID = strings.Join(projects, ",") @@ -1121,26 +1148,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } else { if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure) - oauthStatus[state] = "Failed to complete Gemini CLI onboarding" + setOAuthStatus(state, "Failed to complete Gemini CLI onboarding") return } if strings.TrimSpace(ts.ProjectID) == "" { log.Error("Onboarding did not return a project ID") - oauthStatus[state] = "Failed to resolve project ID" + setOAuthStatus(state, "Failed to resolve project ID") return } isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) if errCheck != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) - oauthStatus[state] = "Failed to verify Cloud AI API status" + setOAuthStatus(state, "Failed to verify Cloud AI API status") return } ts.Checked = isChecked if !isChecked { log.Error("Cloud AI API is not enabled for the selected project") - oauthStatus[state] = "Cloud AI API not enabled" + setOAuthStatus(state, "Cloud AI API not enabled") return } } @@ -1163,15 +1190,15 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save token to file: %v", errSave) - oauthStatus[state] = "Failed to save token to file" + setOAuthStatus(state, "Failed to save token to file") return } - delete(oauthStatus, state) + deleteOAuthStatus(state) fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1235,7 +1262,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { if time.Now().After(deadline) { authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) log.Error(codex.GetUserFriendlyMessage(authErr)) - oauthStatus[state] = "Timeout waiting for OAuth callback" + setOAuthStatus(state, "Timeout waiting for OAuth callback") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -1245,12 +1272,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { if errStr := m["error"]; errStr != "" { oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(codex.GetUserFriendlyMessage(oauthErr)) - oauthStatus[state] = "Bad Request" + setOAuthStatus(state, "Bad Request") return } if m["state"] != state { authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"])) - oauthStatus[state] = "State code error" + setOAuthStatus(state, "State code error") log.Error(codex.GetUserFriendlyMessage(authErr)) return } @@ -1281,14 +1308,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo) - oauthStatus[state] = "Failed to exchange authorization code for tokens" + setOAuthStatus(state, "Failed to exchange authorization code for tokens") log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) return } defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { - oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)) log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) return } @@ -1299,7 +1326,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { ExpiresIn int `json:"expires_in"` } if errU := json.Unmarshal(respBody, &tokenResp); errU != nil { - oauthStatus[state] = "Failed to parse token response" + setOAuthStatus(state, "Failed to parse token response") log.Errorf("failed to parse token response: %v", errU) return } @@ -1337,8 +1364,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - oauthStatus[state] = "Failed to save authentication tokens" log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") return } fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) @@ -1346,10 +1373,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Codex services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1416,7 +1443,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - oauthStatus[state] = "OAuth flow timed out" + setOAuthStatus(state, "OAuth flow timed out") return } if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil { @@ -1425,18 +1452,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := strings.TrimSpace(payload["error"]); errStr != "" { log.Errorf("Authentication failed: %s", errStr) - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") return } if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state { log.Errorf("Authentication failed: state mismatch") - oauthStatus[state] = "Authentication failed: state mismatch" + setOAuthStatus(state, "Authentication failed: state mismatch") return } authCode = strings.TrimSpace(payload["code"]) if authCode == "" { log.Error("Authentication failed: code not found") - oauthStatus[state] = "Authentication failed: code not found" + setOAuthStatus(state, "Authentication failed: code not found") return } break @@ -1455,7 +1482,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode())) if errNewRequest != nil { log.Errorf("Failed to build token request: %v", errNewRequest) - oauthStatus[state] = "Failed to build token request" + setOAuthStatus(state, "Failed to build token request") return } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -1463,7 +1490,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { log.Errorf("Failed to execute token request: %v", errDo) - oauthStatus[state] = "Failed to exchange token" + setOAuthStatus(state, "Failed to exchange token") return } defer func() { @@ -1475,7 +1502,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { bodyBytes, _ := io.ReadAll(resp.Body) log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) return } @@ -1487,7 +1514,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil { log.Errorf("Failed to parse token response: %v", errDecode) - oauthStatus[state] = "Failed to parse token response" + setOAuthStatus(state, "Failed to parse token response") return } @@ -1496,7 +1523,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errInfoReq != nil { log.Errorf("Failed to build user info request: %v", errInfoReq) - oauthStatus[state] = "Failed to build user info request" + setOAuthStatus(state, "Failed to build user info request") return } infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) @@ -1504,7 +1531,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoResp, errInfo := httpClient.Do(infoReq) if errInfo != nil { log.Errorf("Failed to execute user info request: %v", errInfo) - oauthStatus[state] = "Failed to execute user info request" + setOAuthStatus(state, "Failed to execute user info request") return } defer func() { @@ -1523,7 +1550,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } else { bodyBytes, _ := io.ReadAll(infoResp.Body) log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode) + setOAuthStatus(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)) return } } @@ -1571,11 +1598,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save token to file: %v", errSave) - oauthStatus[state] = "Failed to save token to file" + setOAuthStatus(state, "Failed to save token to file") return } - delete(oauthStatus, state) + deleteOAuthStatus(state) fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) if projectID != "" { fmt.Printf("Using GCP project: %s\n", projectID) @@ -1583,7 +1610,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { fmt.Println("You can now use Antigravity services through this CLI") }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1609,7 +1636,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { fmt.Println("Waiting for authentication...") tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) if errPollForToken != nil { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errPollForToken) return } @@ -1628,16 +1655,16 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - oauthStatus[state] = "Failed to save authentication tokens" + setOAuthStatus(state, "Failed to save authentication tokens") return } fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) fmt.Println("You can now use Qwen services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1676,7 +1703,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { var resultMap map[string]string for { if time.Now().After(deadline) { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: timeout waiting for callback") return } @@ -1689,26 +1716,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { } if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %s\n", errStr) return } if resultState := strings.TrimSpace(resultMap["state"]); resultState != state { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: state mismatch") return } code := strings.TrimSpace(resultMap["code"]) if code == "" { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Println("Authentication failed: code missing") return } tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI) if errExchange != nil { - oauthStatus[state] = "Authentication failed" + setOAuthStatus(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errExchange) return } @@ -1730,8 +1757,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - oauthStatus[state] = "Failed to save authentication tokens" log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") return } @@ -1740,10 +1767,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use iFlow services through this CLI") - delete(oauthStatus, state) + deleteOAuthStatus(state) }() - oauthStatus[state] = "" + setOAuthStatus(state, "") c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -2180,9 +2207,35 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec func (h *Handler) GetAuthStatus(c *gin.Context) { state := c.Query("state") - if err, ok := oauthStatus[state]; ok { - if err != "" { - c.JSON(200, gin.H{"status": "error", "error": err}) + if statusValue, ok := getOAuthStatus(state); ok { + if statusValue != "" { + // Check for device_code prefix (Kiro AWS Builder ID flow) + // Format: "device_code|verification_url|user_code" + // Using "|" as separator because URLs contain ":" + if strings.HasPrefix(statusValue, "device_code|") { + parts := strings.SplitN(statusValue, "|", 3) + if len(parts) == 3 { + c.JSON(200, gin.H{ + "status": "device_code", + "verification_url": parts[1], + "user_code": parts[2], + }) + return + } + } + // Check for auth_url prefix (Kiro social auth flow) + // Format: "auth_url|url" + // Using "|" as separator because URLs contain ":" + if strings.HasPrefix(statusValue, "auth_url|") { + authURL := strings.TrimPrefix(statusValue, "auth_url|") + c.JSON(200, gin.H{ + "status": "auth_url", + "url": authURL, + }) + return + } + // Otherwise treat as error + c.JSON(200, gin.H{"status": "error", "error": statusValue}) } else { c.JSON(200, gin.H{"status": "wait"}) return @@ -2190,5 +2243,297 @@ func (h *Handler) GetAuthStatus(c *gin.Context) { } else { c.JSON(200, gin.H{"status": "ok"}) } - delete(oauthStatus, state) + deleteOAuthStatus(state) +} + +const kiroCallbackPort = 9876 + +func (h *Handler) RequestKiroToken(c *gin.Context) { + ctx := context.Background() + + // Get the login method from query parameter (default: aws for device code flow) + method := strings.ToLower(strings.TrimSpace(c.Query("method"))) + if method == "" { + method = "aws" + } + + fmt.Println("Initializing Kiro authentication...") + + state := fmt.Sprintf("kiro-%d", time.Now().UnixNano()) + + switch method { + case "aws", "builder-id": + // AWS Builder ID uses device code flow (no callback needed) + go func() { + ssoClient := kiroauth.NewSSOOIDCClient(h.cfg) + + // Step 1: Register client + fmt.Println("Registering client...") + regResp, err := ssoClient.RegisterClient(ctx) + if err != nil { + log.Errorf("Failed to register client: %v", err) + setOAuthStatus(state, "Failed to register client") + return + } + + // Step 2: Start device authorization + fmt.Println("Starting device authorization...") + authResp, err := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret) + if err != nil { + log.Errorf("Failed to start device auth: %v", err) + setOAuthStatus(state, "Failed to start device authorization") + return + } + + // Store the verification URL for the frontend to display + // Using "|" as separator because URLs contain ":" + setOAuthStatus(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode) + + // Step 3: Poll for token + fmt.Println("Waiting for authorization...") + interval := 5 * time.Second + if authResp.Interval > 0 { + interval = time.Duration(authResp.Interval) * time.Second + } + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + setOAuthStatus(state, "Authorization cancelled") + return + case <-time.After(interval): + tokenResp, err := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "authorization_pending") { + continue + } + if strings.Contains(errStr, "slow_down") { + interval += 5 * time.Second + continue + } + log.Errorf("Token creation failed: %v", err) + setOAuthStatus(state, "Token creation failed") + return + } + + // Success! Save the token + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken) + + idPart := kiroauth.SanitizeEmailForFilename(email) + if idPart == "" { + idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000) + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenResp.AccessToken, + "refresh_token": tokenResp.RefreshToken, + "expires_at": expiresAt.Format(time.RFC3339), + "auth_method": "builder-id", + "provider": "AWS", + "client_id": regResp.ClientID, + "client_secret": regResp.ClientSecret, + "email": email, + "last_refresh": now.Format(time.RFC3339), + }, + } + + savedPath, errSave := h.saveTokenRecord(ctx, record) + if errSave != nil { + log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") + return + } + + fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) + if email != "" { + fmt.Printf("Authenticated as: %s\n", email) + } + deleteOAuthStatus(state) + return + } + } + + setOAuthStatus(state, "Authorization timed out") + }() + + // Return immediately with the state for polling + c.JSON(200, gin.H{"status": "ok", "state": state, "method": "device_code"}) + + case "google", "github": + // Social auth uses protocol handler - for WEB UI we use a callback forwarder + provider := "Google" + if method == "github" { + provider = "Github" + } + + isWebUI := isWebUIRequest(c) + if isWebUI { + targetURL, errTarget := h.managementCallbackURL("/kiro/callback") + if errTarget != nil { + log.WithError(errTarget).Error("failed to compute kiro callback target") + c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) + return + } + if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil { + log.WithError(errStart).Error("failed to start kiro callback forwarder") + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) + return + } + } + + go func() { + if isWebUI { + defer stopCallbackForwarder(kiroCallbackPort) + } + + socialClient := kiroauth.NewSocialAuthClient(h.cfg) + + // Generate PKCE codes + codeVerifier, codeChallenge, err := generateKiroPKCE() + if err != nil { + log.Errorf("Failed to generate PKCE: %v", err) + setOAuthStatus(state, "Failed to generate PKCE") + return + } + + // Build login URL + authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account", + "https://prod.us-east-1.auth.desktop.kiro.dev", + provider, + url.QueryEscape(kiroauth.KiroRedirectURI), + codeChallenge, + state, + ) + + // Store auth URL for frontend + // Using "|" as separator because URLs contain ":" + setOAuthStatus(state, "auth_url|"+authURL) + + // Wait for callback file + waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state)) + deadline := time.Now().Add(5 * time.Minute) + + for { + if time.Now().After(deadline) { + log.Error("oauth flow timed out") + setOAuthStatus(state, "OAuth flow timed out") + return + } + if data, errR := os.ReadFile(waitFile); errR == nil { + var m map[string]string + _ = json.Unmarshal(data, &m) + _ = os.Remove(waitFile) + if errStr := m["error"]; errStr != "" { + log.Errorf("Authentication failed: %s", errStr) + setOAuthStatus(state, "Authentication failed") + return + } + if m["state"] != state { + log.Errorf("State mismatch") + setOAuthStatus(state, "State mismatch") + return + } + code := m["code"] + if code == "" { + log.Error("No authorization code received") + setOAuthStatus(state, "No authorization code received") + return + } + + // Exchange code for tokens + tokenReq := &kiroauth.CreateTokenRequest{ + Code: code, + CodeVerifier: codeVerifier, + RedirectURI: kiroauth.KiroRedirectURI, + } + + tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq) + if errToken != nil { + log.Errorf("Failed to exchange code for tokens: %v", errToken) + setOAuthStatus(state, "Failed to exchange code for tokens") + return + } + + // Save the token + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken) + + idPart := kiroauth.SanitizeEmailForFilename(email) + if idPart == "" { + idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000) + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenResp.AccessToken, + "refresh_token": tokenResp.RefreshToken, + "profile_arn": tokenResp.ProfileArn, + "expires_at": expiresAt.Format(time.RFC3339), + "auth_method": "social", + "provider": provider, + "email": email, + "last_refresh": now.Format(time.RFC3339), + }, + } + + savedPath, errSave := h.saveTokenRecord(ctx, record) + if errSave != nil { + log.Errorf("Failed to save authentication tokens: %v", errSave) + setOAuthStatus(state, "Failed to save authentication tokens") + return + } + + fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) + if email != "" { + fmt.Printf("Authenticated as: %s\n", email) + } + deleteOAuthStatus(state) + return + } + time.Sleep(500 * time.Millisecond) + } + }() + + setOAuthStatus(state, "") + c.JSON(200, gin.H{"status": "ok", "state": state, "method": "social"}) + + default: + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"}) + } +} + +// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth. +func generateKiroPKCE() (verifier, challenge string, err error) { + b := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + + return verifier, challenge, nil } diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index ae292982..f9069198 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -19,8 +19,8 @@ import ( ) const ( - latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPI/releases/latest" - latestReleaseUserAgent = "CLIProxyAPI" + latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest" + latestReleaseUserAgent = "CLIProxyAPIPlus" ) func (h *Handler) GetConfig(c *gin.Context) { diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go index 3c4ef308..435f9bf9 100644 --- a/internal/api/modules/amp/proxy.go +++ b/internal/api/modules/amp/proxy.go @@ -3,8 +3,11 @@ package amp import ( "bytes" "compress/gzip" + "context" + "errors" "fmt" "io" + "net" "net/http" "net/http/httputil" "net/url" @@ -67,7 +70,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi // Modify incoming responses to handle gzip without Content-Encoding // This addresses the same issue as inline handler gzip handling, but at the proxy level proxy.ModifyResponse = func(resp *http.Response) error { - // Only process successful responses + // Log upstream error responses for diagnostics (502, 503, etc.) + // These are NOT proxy connection errors - the upstream responded with an error status + if resp.StatusCode >= 500 { + log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path) + } else if resp.StatusCode >= 400 { + log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path) + } + + // Only process successful responses for gzip decompression if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil } @@ -151,9 +162,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi return nil } - // Error handler for proxy failures + // Error handler for proxy failures with detailed error classification for diagnostics proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { - log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err) + // Classify the error type for better diagnostics + var errType string + if errors.Is(err, context.DeadlineExceeded) { + errType = "timeout" + } else if errors.Is(err, context.Canceled) { + errType = "canceled" + } else if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + errType = "dial_timeout" + } else if _, ok := err.(net.Error); ok { + errType = "network_error" + } else { + errType = "connection_error" + } + + // Don't log as error for context canceled - it's usually client closing connection + if errors.Is(err, context.Canceled) { + log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path) + } else { + log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err) + } + rw.Header().Set("Content-Type", "application/json") rw.WriteHeader(http.StatusBadGateway) _, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`)) diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go index de6ba137..e6f20c57 100644 --- a/internal/api/modules/amp/response_rewriter.go +++ b/internal/api/modules/amp/response_rewriter.go @@ -29,15 +29,71 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe } } +const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap + +func looksLikeSSEChunk(data []byte) bool { + // Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered. + // Heuristics are intentionally simple and cheap. + return bytes.Contains(data, []byte("data:")) || + bytes.Contains(data, []byte("event:")) || + bytes.Contains(data, []byte("message_start")) || + bytes.Contains(data, []byte("message_delta")) || + bytes.Contains(data, []byte("content_block_start")) || + bytes.Contains(data, []byte("content_block_delta")) || + bytes.Contains(data, []byte("content_block_stop")) || + bytes.Contains(data, []byte("\n\n")) +} + +func (rw *ResponseRewriter) enableStreaming(reason string) error { + if rw.isStreaming { + return nil + } + rw.isStreaming = true + + // Flush any previously buffered data to avoid reordering or data loss. + if rw.body != nil && rw.body.Len() > 0 { + buf := rw.body.Bytes() + // Copy before Reset() to keep bytes stable. + toFlush := make([]byte, len(buf)) + copy(toFlush, buf) + rw.body.Reset() + + if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil { + return err + } + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } + } + + log.Debugf("amp response rewriter: switched to streaming (%s)", reason) + return nil +} + // Write intercepts response writes and buffers them for model name replacement func (rw *ResponseRewriter) Write(data []byte) (int, error) { - // Detect streaming on first write - if rw.body.Len() == 0 && !rw.isStreaming { + // Detect streaming on first write (header-based) + if !rw.isStreaming && rw.body.Len() == 0 { contentType := rw.Header().Get("Content-Type") rw.isStreaming = strings.Contains(contentType, "text/event-stream") || strings.Contains(contentType, "stream") } + if !rw.isStreaming { + // Content-based fallback: detect SSE-like chunks even if Content-Type is missing/wrong. + if looksLikeSSEChunk(data) { + if err := rw.enableStreaming("sse heuristic"); err != nil { + return 0, err + } + } else if rw.body.Len()+len(data) > maxBufferedResponseBytes { + // Safety cap: avoid unbounded buffering on large responses. + log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes) + if err := rw.enableStreaming("buffer limit"); err != nil { + return 0, err + } + } + } + if rw.isStreaming { n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data)) if err == nil { diff --git a/internal/api/server.go b/internal/api/server.go index af28e6ad..c43f9971 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -349,6 +349,12 @@ func (s *Server) setupRoutes() { }, }) }) + + // Event logging endpoint - handles Claude Code telemetry requests + // Returns 200 OK to prevent 404 errors in logs + s.engine.POST("/api/event_logging/batch", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + }) s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler) // OAuth callback endpoints (reuse main server port) @@ -415,6 +421,18 @@ func (s *Server) setupRoutes() { c.String(http.StatusOK, oauthCallbackSuccessHTML) }) + s.engine.GET("/kiro/callback", func(c *gin.Context) { + code := c.Query("code") + state := c.Query("state") + errStr := c.Query("error") + if state != "" { + file := fmt.Sprintf("%s/.oauth-kiro-%s.oauth", s.cfg.AuthDir, state) + _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + } + c.Header("Content-Type", "text/html; charset=utf-8") + c.String(http.StatusOK, oauthCallbackSuccessHTML) + }) + // Management routes are registered lazily by registerManagementRoutes when a secret is configured. } @@ -581,6 +599,7 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken) mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken) + mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) } } @@ -923,7 +942,7 @@ func (s *Server) UpdateClients(cfg *config.Config) { for _, p := range cfg.OpenAICompatibility { providerNames = append(providerNames, p.Name) } - s.handlers.OpenAICompatProviders = providerNames + s.handlers.SetOpenAICompatProviders(providerNames) s.handlers.UpdateClients(&cfg.SDKConfig) diff --git a/internal/auth/claude/oauth_server.go b/internal/auth/claude/oauth_server.go index a6ebe2f7..49b04794 100644 --- a/internal/auth/claude/oauth_server.go +++ b/internal/auth/claude/oauth_server.go @@ -242,6 +242,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { platformURL = "https://console.anthropic.com/" } + // Validate platformURL to prevent XSS - only allow http/https URLs + if !isValidURL(platformURL) { + platformURL = "https://console.anthropic.com/" + } + // Generate success page HTML with dynamic content successHTML := s.generateSuccessHTML(setupRequired, platformURL) @@ -251,6 +256,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { } } +// isValidURL checks if the URL is a valid http/https URL to prevent XSS +func isValidURL(urlStr string) bool { + urlStr = strings.TrimSpace(urlStr) + return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://") +} + // generateSuccessHTML creates the HTML content for the success page. // It customizes the page based on whether additional setup is required // and includes a link to the platform. diff --git a/internal/auth/codex/oauth_server.go b/internal/auth/codex/oauth_server.go index 9c6a6c5b..58b5394e 100644 --- a/internal/auth/codex/oauth_server.go +++ b/internal/auth/codex/oauth_server.go @@ -239,6 +239,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { platformURL = "https://platform.openai.com" } + // Validate platformURL to prevent XSS - only allow http/https URLs + if !isValidURL(platformURL) { + platformURL = "https://platform.openai.com" + } + // Generate success page HTML with dynamic content successHTML := s.generateSuccessHTML(setupRequired, platformURL) @@ -248,6 +253,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { } } +// isValidURL checks if the URL is a valid http/https URL to prevent XSS +func isValidURL(urlStr string) bool { + urlStr = strings.TrimSpace(urlStr) + return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://") +} + // generateSuccessHTML creates the HTML content for the success page. // It customizes the page based on whether additional setup is required // and includes a link to the platform. diff --git a/internal/auth/copilot/copilot_auth.go b/internal/auth/copilot/copilot_auth.go new file mode 100644 index 00000000..c40e7082 --- /dev/null +++ b/internal/auth/copilot/copilot_auth.go @@ -0,0 +1,225 @@ +// Package copilot provides authentication and token management for GitHub Copilot API. +// It handles the OAuth2 device flow for secure authentication with the Copilot API. +package copilot + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // copilotAPITokenURL is the endpoint for getting Copilot API tokens from GitHub token. + copilotAPITokenURL = "https://api.github.com/copilot_internal/v2/token" + // copilotAPIEndpoint is the base URL for making API requests. + copilotAPIEndpoint = "https://api.githubcopilot.com" + + // Common HTTP header values for Copilot API requests. + copilotUserAgent = "GithubCopilot/1.0" + copilotEditorVersion = "vscode/1.100.0" + copilotPluginVersion = "copilot/1.300.0" + copilotIntegrationID = "vscode-chat" + copilotOpenAIIntent = "conversation-panel" +) + +// CopilotAPIToken represents the Copilot API token response. +type CopilotAPIToken struct { + // Token is the JWT token for authenticating with the Copilot API. + Token string `json:"token"` + // ExpiresAt is the Unix timestamp when the token expires. + ExpiresAt int64 `json:"expires_at"` + // Endpoints contains the available API endpoints. + Endpoints struct { + API string `json:"api"` + Proxy string `json:"proxy"` + OriginTracker string `json:"origin-tracker"` + Telemetry string `json:"telemetry"` + } `json:"endpoints,omitempty"` + // ErrorDetails contains error information if the request failed. + ErrorDetails *struct { + URL string `json:"url"` + Message string `json:"message"` + DocumentationURL string `json:"documentation_url"` + } `json:"error_details,omitempty"` +} + +// CopilotAuth handles GitHub Copilot authentication flow. +// It provides methods for device flow authentication and token management. +type CopilotAuth struct { + httpClient *http.Client + deviceClient *DeviceFlowClient + cfg *config.Config +} + +// NewCopilotAuth creates a new CopilotAuth service instance. +// It initializes an HTTP client with proxy settings from the provided configuration. +func NewCopilotAuth(cfg *config.Config) *CopilotAuth { + return &CopilotAuth{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}), + deviceClient: NewDeviceFlowClient(cfg), + cfg: cfg, + } +} + +// StartDeviceFlow initiates the device flow authentication. +// Returns the device code response containing the user code and verification URI. +func (c *CopilotAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) { + return c.deviceClient.RequestDeviceCode(ctx) +} + +// WaitForAuthorization polls for user authorization and returns the auth bundle. +func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotAuthBundle, error) { + tokenData, err := c.deviceClient.PollForToken(ctx, deviceCode) + if err != nil { + return nil, err + } + + // Fetch the GitHub username + username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken) + if err != nil { + log.Warnf("copilot: failed to fetch user info: %v", err) + username = "unknown" + } + + return &CopilotAuthBundle{ + TokenData: tokenData, + Username: username, + }, nil +} + +// GetCopilotAPIToken exchanges a GitHub access token for a Copilot API token. +// This token is used to make authenticated requests to the Copilot API. +func (c *CopilotAuth) GetCopilotAPIToken(ctx context.Context, githubAccessToken string) (*CopilotAPIToken, error) { + if githubAccessToken == "" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("github access token is empty")) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotAPITokenURL, nil) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + req.Header.Set("Authorization", "token "+githubAccessToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", copilotUserAgent) + req.Header.Set("Editor-Version", copilotEditorVersion) + req.Header.Set("Editor-Plugin-Version", copilotPluginVersion) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot api token: close body error: %v", errClose) + } + }() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + if !isHTTPSuccess(resp.StatusCode) { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, + fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) + } + + var apiToken CopilotAPIToken + if err = json.Unmarshal(bodyBytes, &apiToken); err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + if apiToken.Token == "" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty copilot api token")) + } + + return &apiToken, nil +} + +// ValidateToken checks if a GitHub access token is valid by attempting to fetch user info. +func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bool, string, error) { + if accessToken == "" { + return false, "", nil + } + + username, err := c.deviceClient.FetchUserInfo(ctx, accessToken) + if err != nil { + return false, "", err + } + + return true, username, nil +} + +// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle. +func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotTokenStorage { + return &CopilotTokenStorage{ + AccessToken: bundle.TokenData.AccessToken, + TokenType: bundle.TokenData.TokenType, + Scope: bundle.TokenData.Scope, + Username: bundle.Username, + Type: "github-copilot", + } +} + +// LoadAndValidateToken loads a token from storage and validates it. +// Returns the storage if valid, or an error if the token is invalid or expired. +func (c *CopilotAuth) LoadAndValidateToken(ctx context.Context, storage *CopilotTokenStorage) (bool, error) { + if storage == nil || storage.AccessToken == "" { + return false, fmt.Errorf("no token available") + } + + // Check if we can still use the GitHub token to get a Copilot API token + apiToken, err := c.GetCopilotAPIToken(ctx, storage.AccessToken) + if err != nil { + return false, err + } + + // Check if the API token is expired + if apiToken.ExpiresAt > 0 && time.Now().Unix() >= apiToken.ExpiresAt { + return false, fmt.Errorf("copilot api token expired") + } + + return true, nil +} + +// GetAPIEndpoint returns the Copilot API endpoint URL. +func (c *CopilotAuth) GetAPIEndpoint() string { + return copilotAPIEndpoint +} + +// MakeAuthenticatedRequest creates an authenticated HTTP request to the Copilot API. +func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url string, body io.Reader, apiToken *CopilotAPIToken) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+apiToken.Token) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", copilotUserAgent) + req.Header.Set("Editor-Version", copilotEditorVersion) + req.Header.Set("Editor-Plugin-Version", copilotPluginVersion) + req.Header.Set("Openai-Intent", copilotOpenAIIntent) + req.Header.Set("Copilot-Integration-Id", copilotIntegrationID) + + return req, nil +} + +// buildChatCompletionURL builds the URL for chat completions API. +func buildChatCompletionURL() string { + return copilotAPIEndpoint + "/chat/completions" +} + +// isHTTPSuccess checks if the status code indicates success (2xx). +func isHTTPSuccess(statusCode int) bool { + return statusCode >= 200 && statusCode < 300 +} diff --git a/internal/auth/copilot/errors.go b/internal/auth/copilot/errors.go new file mode 100644 index 00000000..a82dd8ec --- /dev/null +++ b/internal/auth/copilot/errors.go @@ -0,0 +1,187 @@ +package copilot + +import ( + "errors" + "fmt" + "net/http" +) + +// OAuthError represents an OAuth-specific error. +type OAuthError struct { + // Code is the OAuth error code. + Code string `json:"error"` + // Description is a human-readable description of the error. + Description string `json:"error_description,omitempty"` + // URI is a URI identifying a human-readable web page with information about the error. + URI string `json:"error_uri,omitempty"` + // StatusCode is the HTTP status code associated with the error. + StatusCode int `json:"-"` +} + +// Error returns a string representation of the OAuth error. +func (e *OAuthError) Error() string { + if e.Description != "" { + return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description) + } + return fmt.Sprintf("OAuth error: %s", e.Code) +} + +// NewOAuthError creates a new OAuth error with the specified code, description, and status code. +func NewOAuthError(code, description string, statusCode int) *OAuthError { + return &OAuthError{ + Code: code, + Description: description, + StatusCode: statusCode, + } +} + +// AuthenticationError represents authentication-related errors. +type AuthenticationError struct { + // Type is the type of authentication error. + Type string `json:"type"` + // Message is a human-readable message describing the error. + Message string `json:"message"` + // Code is the HTTP status code associated with the error. + Code int `json:"code"` + // Cause is the underlying error that caused this authentication error. + Cause error `json:"-"` +} + +// Error returns a string representation of the authentication error. +func (e *AuthenticationError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause) + } + return fmt.Sprintf("%s: %s", e.Type, e.Message) +} + +// Unwrap returns the underlying cause of the error. +func (e *AuthenticationError) Unwrap() error { + return e.Cause +} + +// Common authentication error types for GitHub Copilot device flow. +var ( + // ErrDeviceCodeFailed represents an error when requesting the device code fails. + ErrDeviceCodeFailed = &AuthenticationError{ + Type: "device_code_failed", + Message: "Failed to request device code from GitHub", + Code: http.StatusBadRequest, + } + + // ErrDeviceCodeExpired represents an error when the device code has expired. + ErrDeviceCodeExpired = &AuthenticationError{ + Type: "device_code_expired", + Message: "Device code has expired. Please try again.", + Code: http.StatusGone, + } + + // ErrAuthorizationPending represents a pending authorization state (not an error, used for polling). + ErrAuthorizationPending = &AuthenticationError{ + Type: "authorization_pending", + Message: "Authorization is pending. Waiting for user to authorize.", + Code: http.StatusAccepted, + } + + // ErrSlowDown represents a request to slow down polling. + ErrSlowDown = &AuthenticationError{ + Type: "slow_down", + Message: "Polling too frequently. Slowing down.", + Code: http.StatusTooManyRequests, + } + + // ErrAccessDenied represents an error when the user denies authorization. + ErrAccessDenied = &AuthenticationError{ + Type: "access_denied", + Message: "User denied authorization", + Code: http.StatusForbidden, + } + + // ErrTokenExchangeFailed represents an error when token exchange fails. + ErrTokenExchangeFailed = &AuthenticationError{ + Type: "token_exchange_failed", + Message: "Failed to exchange device code for access token", + Code: http.StatusBadRequest, + } + + // ErrPollingTimeout represents an error when polling times out. + ErrPollingTimeout = &AuthenticationError{ + Type: "polling_timeout", + Message: "Timeout waiting for user authorization", + Code: http.StatusRequestTimeout, + } + + // ErrUserInfoFailed represents an error when fetching user info fails. + ErrUserInfoFailed = &AuthenticationError{ + Type: "user_info_failed", + Message: "Failed to fetch GitHub user information", + Code: http.StatusBadRequest, + } +) + +// NewAuthenticationError creates a new authentication error with a cause based on a base error. +func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError { + return &AuthenticationError{ + Type: baseErr.Type, + Message: baseErr.Message, + Code: baseErr.Code, + Cause: cause, + } +} + +// IsAuthenticationError checks if an error is an authentication error. +func IsAuthenticationError(err error) bool { + var authenticationError *AuthenticationError + ok := errors.As(err, &authenticationError) + return ok +} + +// IsOAuthError checks if an error is an OAuth error. +func IsOAuthError(err error) bool { + var oAuthError *OAuthError + ok := errors.As(err, &oAuthError) + return ok +} + +// GetUserFriendlyMessage returns a user-friendly error message based on the error type. +func GetUserFriendlyMessage(err error) string { + var authErr *AuthenticationError + if errors.As(err, &authErr) { + switch authErr.Type { + case "device_code_failed": + return "Failed to start GitHub authentication. Please check your network connection and try again." + case "device_code_expired": + return "The authentication code has expired. Please try again." + case "authorization_pending": + return "Waiting for you to authorize the application on GitHub." + case "slow_down": + return "Please wait a moment before trying again." + case "access_denied": + return "Authentication was cancelled or denied." + case "token_exchange_failed": + return "Failed to complete authentication. Please try again." + case "polling_timeout": + return "Authentication timed out. Please try again." + case "user_info_failed": + return "Failed to get your GitHub account information. Please try again." + default: + return "Authentication failed. Please try again." + } + } + + var oauthErr *OAuthError + if errors.As(err, &oauthErr) { + switch oauthErr.Code { + case "access_denied": + return "Authentication was cancelled or denied." + case "invalid_request": + return "Invalid authentication request. Please try again." + case "server_error": + return "GitHub server error. Please try again later." + default: + return fmt.Sprintf("Authentication failed: %s", oauthErr.Description) + } + } + + return "An unexpected error occurred. Please try again." +} diff --git a/internal/auth/copilot/oauth.go b/internal/auth/copilot/oauth.go new file mode 100644 index 00000000..d3f46aaa --- /dev/null +++ b/internal/auth/copilot/oauth.go @@ -0,0 +1,255 @@ +package copilot + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // copilotClientID is GitHub's Copilot CLI OAuth client ID. + copilotClientID = "Iv1.b507a08c87ecfe98" + // copilotDeviceCodeURL is the endpoint for requesting device codes. + copilotDeviceCodeURL = "https://github.com/login/device/code" + // copilotTokenURL is the endpoint for exchanging device codes for tokens. + copilotTokenURL = "https://github.com/login/oauth/access_token" + // copilotUserInfoURL is the endpoint for fetching GitHub user information. + copilotUserInfoURL = "https://api.github.com/user" + // defaultPollInterval is the default interval for polling token endpoint. + defaultPollInterval = 5 * time.Second + // maxPollDuration is the maximum time to wait for user authorization. + maxPollDuration = 15 * time.Minute +) + +// DeviceFlowClient handles the OAuth2 device flow for GitHub Copilot. +type DeviceFlowClient struct { + httpClient *http.Client + cfg *config.Config +} + +// NewDeviceFlowClient creates a new device flow client. +func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &DeviceFlowClient{ + httpClient: client, + cfg: cfg, + } +} + +// RequestDeviceCode initiates the device flow by requesting a device code from GitHub. +func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) { + data := url.Values{} + data.Set("client_id", copilotClientID) + data.Set("scope", "user:email") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot device code: close body error: %v", errClose) + } + }() + + if !isHTTPSuccess(resp.StatusCode) { + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, NewAuthenticationError(ErrDeviceCodeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) + } + + var deviceCode DeviceCodeResponse + if err = json.NewDecoder(resp.Body).Decode(&deviceCode); err != nil { + return nil, NewAuthenticationError(ErrDeviceCodeFailed, err) + } + + return &deviceCode, nil +} + +// PollForToken polls the token endpoint until the user authorizes or the device code expires. +func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotTokenData, error) { + if deviceCode == nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("device code is nil")) + } + + interval := time.Duration(deviceCode.Interval) * time.Second + if interval < defaultPollInterval { + interval = defaultPollInterval + } + + deadline := time.Now().Add(maxPollDuration) + if deviceCode.ExpiresIn > 0 { + codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second) + if codeDeadline.Before(deadline) { + deadline = codeDeadline + } + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil, NewAuthenticationError(ErrPollingTimeout, ctx.Err()) + case <-ticker.C: + if time.Now().After(deadline) { + return nil, ErrPollingTimeout + } + + token, err := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode) + if err != nil { + var authErr *AuthenticationError + if errors.As(err, &authErr) { + switch authErr.Type { + case ErrAuthorizationPending.Type: + // Continue polling + continue + case ErrSlowDown.Type: + // Increase interval and continue + interval += 5 * time.Second + ticker.Reset(interval) + continue + case ErrDeviceCodeExpired.Type: + return nil, err + case ErrAccessDenied.Type: + return nil, err + } + } + return nil, err + } + return token, nil + } + } +} + +// exchangeDeviceCode attempts to exchange the device code for an access token. +func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*CopilotTokenData, error) { + data := url.Values{} + data.Set("client_id", copilotClientID) + data.Set("device_code", deviceCode) + data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotTokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot token exchange: close body error: %v", errClose) + } + }() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + // GitHub returns 200 for both success and error cases in device flow + // Check for OAuth error response first + var oauthResp struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + } + + if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, err) + } + + if oauthResp.Error != "" { + switch oauthResp.Error { + case "authorization_pending": + return nil, ErrAuthorizationPending + case "slow_down": + return nil, ErrSlowDown + case "expired_token": + return nil, ErrDeviceCodeExpired + case "access_denied": + return nil, ErrAccessDenied + default: + return nil, NewOAuthError(oauthResp.Error, oauthResp.ErrorDescription, resp.StatusCode) + } + } + + if oauthResp.AccessToken == "" { + return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty access token")) + } + + return &CopilotTokenData{ + AccessToken: oauthResp.AccessToken, + TokenType: oauthResp.TokenType, + Scope: oauthResp.Scope, + }, nil +} + +// FetchUserInfo retrieves the GitHub username for the authenticated user. +func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) { + if accessToken == "" { + return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty")) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil) + if err != nil { + return "", NewAuthenticationError(ErrUserInfoFailed, err) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "CLIProxyAPI") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", NewAuthenticationError(ErrUserInfoFailed, err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("copilot user info: close body error: %v", errClose) + } + }() + + if !isHTTPSuccess(resp.StatusCode) { + bodyBytes, _ := io.ReadAll(resp.Body) + return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes))) + } + + var userInfo struct { + Login string `json:"login"` + } + if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + return "", NewAuthenticationError(ErrUserInfoFailed, err) + } + + if userInfo.Login == "" { + return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username")) + } + + return userInfo.Login, nil +} diff --git a/internal/auth/copilot/token.go b/internal/auth/copilot/token.go new file mode 100644 index 00000000..4e5eed6c --- /dev/null +++ b/internal/auth/copilot/token.go @@ -0,0 +1,93 @@ +// Package copilot provides authentication and token management functionality +// for GitHub Copilot AI services. It handles OAuth2 device flow token storage, +// serialization, and retrieval for maintaining authenticated sessions with the Copilot API. +package copilot + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" +) + +// CopilotTokenStorage stores OAuth2 token information for GitHub Copilot API authentication. +// It maintains compatibility with the existing auth system while adding Copilot-specific fields +// for managing access tokens and user account information. +type CopilotTokenStorage struct { + // AccessToken is the OAuth2 access token used for authenticating API requests. + AccessToken string `json:"access_token"` + // TokenType is the type of token, typically "bearer". + TokenType string `json:"token_type"` + // Scope is the OAuth2 scope granted to the token. + Scope string `json:"scope"` + // ExpiresAt is the timestamp when the access token expires (if provided). + ExpiresAt string `json:"expires_at,omitempty"` + // Username is the GitHub username associated with this token. + Username string `json:"username"` + // Type indicates the authentication provider type, always "github-copilot" for this storage. + Type string `json:"type"` +} + +// CopilotTokenData holds the raw OAuth token response from GitHub. +type CopilotTokenData struct { + // AccessToken is the OAuth2 access token. + AccessToken string `json:"access_token"` + // TokenType is the type of token, typically "bearer". + TokenType string `json:"token_type"` + // Scope is the OAuth2 scope granted to the token. + Scope string `json:"scope"` +} + +// CopilotAuthBundle bundles authentication data for storage. +type CopilotAuthBundle struct { + // TokenData contains the OAuth token information. + TokenData *CopilotTokenData + // Username is the GitHub username. + Username string +} + +// DeviceCodeResponse represents GitHub's device code response. +type DeviceCodeResponse struct { + // DeviceCode is the device verification code. + DeviceCode string `json:"device_code"` + // UserCode is the code the user must enter at the verification URI. + UserCode string `json:"user_code"` + // VerificationURI is the URL where the user should enter the code. + VerificationURI string `json:"verification_uri"` + // ExpiresIn is the number of seconds until the device code expires. + ExpiresIn int `json:"expires_in"` + // Interval is the minimum number of seconds to wait between polling requests. + Interval int `json:"interval"` +} + +// SaveTokenToFile serializes the Copilot token storage to a JSON file. +// This method creates the necessary directory structure and writes the token +// data in JSON format to the specified file path for persistent storage. +// +// Parameters: +// - authFilePath: The full path where the token file should be saved +// +// Returns: +// - error: An error if the operation fails, nil otherwise +func (ts *CopilotTokenStorage) SaveTokenToFile(authFilePath string) error { + misc.LogSavingCredentials(authFilePath) + ts.Type = "github-copilot" + if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + return fmt.Errorf("failed to create directory: %v", err) + } + + f, err := os.Create(authFilePath) + if err != nil { + return fmt.Errorf("failed to create token file: %w", err) + } + defer func() { + _ = f.Close() + }() + + if err = json.NewEncoder(f).Encode(ts); err != nil { + return fmt.Errorf("failed to write token to file: %w", err) + } + return nil +} diff --git a/internal/auth/iflow/iflow_auth.go b/internal/auth/iflow/iflow_auth.go index fa9f38c3..279d7339 100644 --- a/internal/auth/iflow/iflow_auth.go +++ b/internal/auth/iflow/iflow_auth.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/url" + "os" "strings" "time" @@ -28,10 +29,21 @@ const ( iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey" // Client credentials provided by iFlow for the Code Assist integration. - iFlowOAuthClientID = "10009311001" - iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW" + iFlowOAuthClientID = "10009311001" + // Default client secret (can be overridden via IFLOW_CLIENT_SECRET env var) + defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW" ) +// getIFlowClientSecret returns the iFlow OAuth client secret. +// It first checks the IFLOW_CLIENT_SECRET environment variable, +// falling back to the default value if not set. +func getIFlowClientSecret() string { + if secret := os.Getenv("IFLOW_CLIENT_SECRET"); secret != "" { + return secret + } + return defaultIFlowClientSecret +} + // DefaultAPIBaseURL is the canonical chat completions endpoint. const DefaultAPIBaseURL = "https://apis.iflow.cn/v1" @@ -72,7 +84,7 @@ func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectUR form.Set("code", code) form.Set("redirect_uri", redirectURI) form.Set("client_id", iFlowOAuthClientID) - form.Set("client_secret", iFlowOAuthClientSecret) + form.Set("client_secret", getIFlowClientSecret()) req, err := ia.newTokenRequest(ctx, form) if err != nil { @@ -88,7 +100,7 @@ func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*I form.Set("grant_type", "refresh_token") form.Set("refresh_token", refreshToken) form.Set("client_id", iFlowOAuthClientID) - form.Set("client_secret", iFlowOAuthClientSecret) + form.Set("client_secret", getIFlowClientSecret()) req, err := ia.newTokenRequest(ctx, form) if err != nil { @@ -104,7 +116,7 @@ func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*htt return nil, fmt.Errorf("iflow token: create request failed: %w", err) } - basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret)) + basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + getIFlowClientSecret())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Basic "+basic) diff --git a/internal/auth/kiro/aws.go b/internal/auth/kiro/aws.go new file mode 100644 index 00000000..9be025c2 --- /dev/null +++ b/internal/auth/kiro/aws.go @@ -0,0 +1,301 @@ +// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API. +// It includes interfaces and implementations for token storage and authentication methods. +package kiro + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" +) + +// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow +type PKCECodes struct { + // CodeVerifier is the cryptographically random string used to correlate + // the authorization request to the token request + CodeVerifier string `json:"code_verifier"` + // CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded + CodeChallenge string `json:"code_challenge"` +} + +// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro) +type KiroTokenData struct { + // AccessToken is the OAuth2 access token for API access + AccessToken string `json:"accessToken"` + // RefreshToken is used to obtain new access tokens + RefreshToken string `json:"refreshToken"` + // ProfileArn is the AWS CodeWhisperer profile ARN + ProfileArn string `json:"profileArn"` + // ExpiresAt is the timestamp when the token expires + ExpiresAt string `json:"expiresAt"` + // AuthMethod indicates the authentication method used (e.g., "builder-id", "social") + AuthMethod string `json:"authMethod"` + // Provider indicates the OAuth provider (e.g., "AWS", "Google") + Provider string `json:"provider"` + // ClientID is the OIDC client ID (needed for token refresh) + ClientID string `json:"clientId,omitempty"` + // ClientSecret is the OIDC client secret (needed for token refresh) + ClientSecret string `json:"clientSecret,omitempty"` + // Email is the user's email address (used for file naming) + Email string `json:"email,omitempty"` +} + +// KiroAuthBundle aggregates authentication data after OAuth flow completion +type KiroAuthBundle struct { + // TokenData contains the OAuth tokens from the authentication flow + TokenData KiroTokenData `json:"token_data"` + // LastRefresh is the timestamp of the last token refresh + LastRefresh string `json:"last_refresh"` +} + +// KiroUsageInfo represents usage information from CodeWhisperer API +type KiroUsageInfo struct { + // SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE") + SubscriptionTitle string `json:"subscription_title"` + // CurrentUsage is the current credit usage + CurrentUsage float64 `json:"current_usage"` + // UsageLimit is the maximum credit limit + UsageLimit float64 `json:"usage_limit"` + // NextReset is the timestamp of the next usage reset + NextReset string `json:"next_reset"` +} + +// KiroModel represents a model available through the CodeWhisperer API +type KiroModel struct { + // ModelID is the unique identifier for the model + ModelID string `json:"modelId"` + // ModelName is the human-readable name + ModelName string `json:"modelName"` + // Description is the model description + Description string `json:"description"` + // RateMultiplier is the credit multiplier for this model + RateMultiplier float64 `json:"rateMultiplier"` + // RateUnit is the unit for rate calculation (e.g., "credit") + RateUnit string `json:"rateUnit"` + // MaxInputTokens is the maximum input token limit + MaxInputTokens int `json:"maxInputTokens,omitempty"` +} + +// KiroIDETokenFile is the default path to Kiro IDE's token file +const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json" + +// LoadKiroIDEToken loads token data from Kiro IDE's token file. +func LoadKiroIDEToken() (*KiroTokenData, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + + tokenPath := filepath.Join(homeDir, KiroIDETokenFile) + data, err := os.ReadFile(tokenPath) + if err != nil { + return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err) + } + + var token KiroTokenData + if err := json.Unmarshal(data, &token); err != nil { + return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err) + } + + if token.AccessToken == "" { + return nil, fmt.Errorf("access token is empty in Kiro IDE token file") + } + + return &token, nil +} + +// LoadKiroTokenFromPath loads token data from a custom path. +// This supports multiple accounts by allowing different token files. +func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) { + // Expand ~ to home directory + if len(tokenPath) > 0 && tokenPath[0] == '~' { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + tokenPath = filepath.Join(homeDir, tokenPath[1:]) + } + + data, err := os.ReadFile(tokenPath) + if err != nil { + return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err) + } + + var token KiroTokenData + if err := json.Unmarshal(data, &token); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + if token.AccessToken == "" { + return nil, fmt.Errorf("access token is empty in token file") + } + + return &token, nil +} + +// ListKiroTokenFiles lists all Kiro token files in the cache directory. +// This supports multiple accounts by finding all token files. +func ListKiroTokenFiles() ([]string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + + cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache") + + // Check if directory exists + if _, err := os.Stat(cacheDir); os.IsNotExist(err) { + return nil, nil // No token files + } + + entries, err := os.ReadDir(cacheDir) + if err != nil { + return nil, fmt.Errorf("failed to read cache directory: %w", err) + } + + var tokenFiles []string + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + // Look for kiro token files only (avoid matching unrelated AWS SSO cache files) + if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") { + tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name)) + } + } + + return tokenFiles, nil +} + +// LoadAllKiroTokens loads all Kiro tokens from the cache directory. +// This supports multiple accounts. +func LoadAllKiroTokens() ([]*KiroTokenData, error) { + files, err := ListKiroTokenFiles() + if err != nil { + return nil, err + } + + var tokens []*KiroTokenData + for _, file := range files { + token, err := LoadKiroTokenFromPath(file) + if err != nil { + // Skip invalid token files + continue + } + tokens = append(tokens, token) + } + + return tokens, nil +} + +// JWTClaims represents the claims we care about from a JWT token. +// JWT tokens from Kiro/AWS contain user information in the payload. +type JWTClaims struct { + Email string `json:"email,omitempty"` + Sub string `json:"sub,omitempty"` + PreferredUser string `json:"preferred_username,omitempty"` + Name string `json:"name,omitempty"` + Iss string `json:"iss,omitempty"` +} + +// ExtractEmailFromJWT extracts the user's email from a JWT access token. +// JWT tokens typically have format: header.payload.signature +// The payload is base64url-encoded JSON containing user claims. +func ExtractEmailFromJWT(accessToken string) string { + if accessToken == "" { + return "" + } + + // JWT format: header.payload.signature + parts := strings.Split(accessToken, ".") + if len(parts) != 3 { + return "" + } + + // Decode the payload (second part) + payload := parts[1] + + // Add padding if needed (base64url requires padding) + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + + decoded, err := base64.URLEncoding.DecodeString(payload) + if err != nil { + // Try RawURLEncoding (no padding) + decoded, err = base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "" + } + } + + var claims JWTClaims + if err := json.Unmarshal(decoded, &claims); err != nil { + return "" + } + + // Return email if available + if claims.Email != "" { + return claims.Email + } + + // Fallback to preferred_username (some providers use this) + if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") { + return claims.PreferredUser + } + + // Fallback to sub if it looks like an email + if claims.Sub != "" && strings.Contains(claims.Sub, "@") { + return claims.Sub + } + + return "" +} + +// SanitizeEmailForFilename sanitizes an email address for use in a filename. +// Replaces special characters with underscores and prevents path traversal attacks. +// Also handles URL-encoded characters to prevent encoded path traversal attempts. +func SanitizeEmailForFilename(email string) string { + if email == "" { + return "" + } + + result := email + + // First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.) + // This prevents encoded characters from bypassing the sanitization. + // Note: We replace % last to catch any remaining encodings including double-encoding (%252F) + result = strings.ReplaceAll(result, "%2F", "_") // / + result = strings.ReplaceAll(result, "%2f", "_") + result = strings.ReplaceAll(result, "%5C", "_") // \ + result = strings.ReplaceAll(result, "%5c", "_") + result = strings.ReplaceAll(result, "%2E", "_") // . + result = strings.ReplaceAll(result, "%2e", "_") + result = strings.ReplaceAll(result, "%00", "_") // null byte + result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks + + // Replace characters that are problematic in filenames + // Keep @ and . in middle but replace other special characters + for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} { + result = strings.ReplaceAll(result, char, "_") + } + + // Prevent path traversal: replace leading dots in each path component + // This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd" + parts := strings.Split(result, "_") + for i, part := range parts { + for strings.HasPrefix(part, ".") { + part = "_" + part[1:] + } + parts[i] = part + } + result = strings.Join(parts, "_") + + return result +} diff --git a/internal/auth/kiro/aws_auth.go b/internal/auth/kiro/aws_auth.go new file mode 100644 index 00000000..53c77a8b --- /dev/null +++ b/internal/auth/kiro/aws_auth.go @@ -0,0 +1,314 @@ +// Package kiro provides OAuth2 authentication functionality for AWS CodeWhisperer (Kiro) API. +// This package implements token loading, refresh, and API communication with CodeWhisperer. +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.) + // Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com) + // used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct + // for their respective API operations. + awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com" + defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json" + targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits" + targetListModels = "AmazonCodeWhispererService.ListAvailableModels" + targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse" +) + +// KiroAuth handles AWS CodeWhisperer authentication and API communication. +// It provides methods for loading tokens, refreshing expired tokens, +// and communicating with the CodeWhisperer API. +type KiroAuth struct { + httpClient *http.Client + endpoint string +} + +// NewKiroAuth creates a new Kiro authentication service. +// It initializes the HTTP client with proxy settings from the configuration. +// +// Parameters: +// - cfg: The application configuration containing proxy settings +// +// Returns: +// - *KiroAuth: A new Kiro authentication service instance +func NewKiroAuth(cfg *config.Config) *KiroAuth { + return &KiroAuth{ + httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}), + endpoint: awsKiroEndpoint, + } +} + +// LoadTokenFromFile loads token data from a file path. +// This method reads and parses the token file, expanding ~ to the home directory. +// +// Parameters: +// - tokenFile: Path to the token file (supports ~ expansion) +// +// Returns: +// - *KiroTokenData: The parsed token data +// - error: An error if file reading or parsing fails +func (k *KiroAuth) LoadTokenFromFile(tokenFile string) (*KiroTokenData, error) { + // Expand ~ to home directory + if strings.HasPrefix(tokenFile, "~") { + home, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("failed to get home directory: %w", err) + } + tokenFile = filepath.Join(home, tokenFile[1:]) + } + + data, err := os.ReadFile(tokenFile) + if err != nil { + return nil, fmt.Errorf("failed to read token file: %w", err) + } + + var tokenData KiroTokenData + if err := json.Unmarshal(data, &tokenData); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + return &tokenData, nil +} + +// IsTokenExpired checks if the token has expired. +// This method parses the expiration timestamp and compares it with the current time. +// +// Parameters: +// - tokenData: The token data to check +// +// Returns: +// - bool: True if the token has expired, false otherwise +func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool { + if tokenData.ExpiresAt == "" { + return true + } + + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + // Try alternate format + expiresAt, err = time.Parse("2006-01-02T15:04:05.000Z", tokenData.ExpiresAt) + if err != nil { + return true + } + } + + return time.Now().After(expiresAt) +} + +// makeRequest sends a request to the CodeWhisperer API. +// This is an internal method for making authenticated API calls. +// +// Parameters: +// - ctx: The context for the request +// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits") +// - accessToken: The OAuth access token +// - payload: The request payload +// +// Returns: +// - []byte: The response body +// - error: An error if the request fails +func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) { + jsonBody, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", target) + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := k.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("failed to close response body: %v", errClose) + } + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + return body, nil +} + +// GetUsageLimits retrieves usage information from the CodeWhisperer API. +// This method fetches the current usage statistics and subscription information. +// +// Parameters: +// - ctx: The context for the request +// - tokenData: The token data containing access token and profile ARN +// +// Returns: +// - *KiroUsageInfo: The usage information +// - error: An error if the request fails +func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "profileArn": tokenData.ProfileArn, + "resourceType": "AGENTIC_REQUEST", + } + + body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload) + if err != nil { + return nil, err + } + + var result struct { + SubscriptionInfo struct { + SubscriptionTitle string `json:"subscriptionTitle"` + } `json:"subscriptionInfo"` + UsageBreakdownList []struct { + CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"` + UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"` + } `json:"usageBreakdownList"` + NextDateReset float64 `json:"nextDateReset"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse usage response: %w", err) + } + + usage := &KiroUsageInfo{ + SubscriptionTitle: result.SubscriptionInfo.SubscriptionTitle, + NextReset: fmt.Sprintf("%v", result.NextDateReset), + } + + if len(result.UsageBreakdownList) > 0 { + usage.CurrentUsage = result.UsageBreakdownList[0].CurrentUsageWithPrecision + usage.UsageLimit = result.UsageBreakdownList[0].UsageLimitWithPrecision + } + + return usage, nil +} + +// ListAvailableModels retrieves available models from the CodeWhisperer API. +// This method fetches the list of AI models available for the authenticated user. +// +// Parameters: +// - ctx: The context for the request +// - tokenData: The token data containing access token and profile ARN +// +// Returns: +// - []*KiroModel: The list of available models +// - error: An error if the request fails +func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + "profileArn": tokenData.ProfileArn, + } + + body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload) + if err != nil { + return nil, err + } + + var result struct { + Models []struct { + ModelID string `json:"modelId"` + ModelName string `json:"modelName"` + Description string `json:"description"` + RateMultiplier float64 `json:"rateMultiplier"` + RateUnit string `json:"rateUnit"` + TokenLimits struct { + MaxInputTokens int `json:"maxInputTokens"` + } `json:"tokenLimits"` + } `json:"models"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse models response: %w", err) + } + + models := make([]*KiroModel, 0, len(result.Models)) + for _, m := range result.Models { + models = append(models, &KiroModel{ + ModelID: m.ModelID, + ModelName: m.ModelName, + Description: m.Description, + RateMultiplier: m.RateMultiplier, + RateUnit: m.RateUnit, + MaxInputTokens: m.TokenLimits.MaxInputTokens, + }) + } + + return models, nil +} + +// CreateTokenStorage creates a new KiroTokenStorage from token data. +// This method converts the token data into a storage structure suitable for persistence. +// +// Parameters: +// - tokenData: The token data to convert +// +// Returns: +// - *KiroTokenStorage: A new token storage instance +func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorage { + return &KiroTokenStorage{ + AccessToken: tokenData.AccessToken, + RefreshToken: tokenData.RefreshToken, + ProfileArn: tokenData.ProfileArn, + ExpiresAt: tokenData.ExpiresAt, + AuthMethod: tokenData.AuthMethod, + Provider: tokenData.Provider, + LastRefresh: time.Now().Format(time.RFC3339), + } +} + +// ValidateToken checks if the token is valid by making a test API call. +// This method verifies the token by attempting to fetch usage limits. +// +// Parameters: +// - ctx: The context for the request +// - tokenData: The token data to validate +// +// Returns: +// - error: An error if the token is invalid +func (k *KiroAuth) ValidateToken(ctx context.Context, tokenData *KiroTokenData) error { + _, err := k.GetUsageLimits(ctx, tokenData) + return err +} + +// UpdateTokenStorage updates an existing token storage with new token data. +// This method refreshes the token storage with newly obtained access and refresh tokens. +// +// Parameters: +// - storage: The existing token storage to update +// - tokenData: The new token data to apply +func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *KiroTokenData) { + storage.AccessToken = tokenData.AccessToken + storage.RefreshToken = tokenData.RefreshToken + storage.ProfileArn = tokenData.ProfileArn + storage.ExpiresAt = tokenData.ExpiresAt + storage.AuthMethod = tokenData.AuthMethod + storage.Provider = tokenData.Provider + storage.LastRefresh = time.Now().Format(time.RFC3339) +} diff --git a/internal/auth/kiro/aws_test.go b/internal/auth/kiro/aws_test.go new file mode 100644 index 00000000..5f60294c --- /dev/null +++ b/internal/auth/kiro/aws_test.go @@ -0,0 +1,161 @@ +package kiro + +import ( + "encoding/base64" + "encoding/json" + "testing" +) + +func TestExtractEmailFromJWT(t *testing.T) { + tests := []struct { + name string + token string + expected string + }{ + { + name: "Empty token", + token: "", + expected: "", + }, + { + name: "Invalid token format", + token: "not.a.valid.jwt", + expected: "", + }, + { + name: "Invalid token - not base64", + token: "xxx.yyy.zzz", + expected: "", + }, + { + name: "Valid JWT with email", + token: createTestJWT(map[string]any{"email": "test@example.com", "sub": "user123"}), + expected: "test@example.com", + }, + { + name: "JWT without email but with preferred_username", + token: createTestJWT(map[string]any{"preferred_username": "user@domain.com", "sub": "user123"}), + expected: "user@domain.com", + }, + { + name: "JWT with email-like sub", + token: createTestJWT(map[string]any{"sub": "another@test.com"}), + expected: "another@test.com", + }, + { + name: "JWT without any email fields", + token: createTestJWT(map[string]any{"sub": "user123", "name": "Test User"}), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ExtractEmailFromJWT(tt.token) + if result != tt.expected { + t.Errorf("ExtractEmailFromJWT() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestSanitizeEmailForFilename(t *testing.T) { + tests := []struct { + name string + email string + expected string + }{ + { + name: "Empty email", + email: "", + expected: "", + }, + { + name: "Simple email", + email: "user@example.com", + expected: "user@example.com", + }, + { + name: "Email with space", + email: "user name@example.com", + expected: "user_name@example.com", + }, + { + name: "Email with special chars", + email: "user:name@example.com", + expected: "user_name@example.com", + }, + { + name: "Email with multiple special chars", + email: "user/name:test@example.com", + expected: "user_name_test@example.com", + }, + { + name: "Path traversal attempt", + email: "../../../etc/passwd", + expected: "_.__.__._etc_passwd", + }, + { + name: "Path traversal with backslash", + email: `..\..\..\..\windows\system32`, + expected: "_.__.__.__._windows_system32", + }, + { + name: "Null byte injection attempt", + email: "user\x00@evil.com", + expected: "user_@evil.com", + }, + // URL-encoded path traversal tests + { + name: "URL-encoded slash", + email: "user%2Fpath@example.com", + expected: "user_path@example.com", + }, + { + name: "URL-encoded backslash", + email: "user%5Cpath@example.com", + expected: "user_path@example.com", + }, + { + name: "URL-encoded dot", + email: "%2E%2E%2Fetc%2Fpasswd", + expected: "___etc_passwd", + }, + { + name: "URL-encoded null", + email: "user%00@evil.com", + expected: "user_@evil.com", + }, + { + name: "Double URL-encoding attack", + email: "%252F%252E%252E", + expected: "_252F_252E_252E", // % replaced with _, remaining chars preserved (safe) + }, + { + name: "Mixed case URL-encoding", + email: "%2f%2F%5c%5C", + expected: "____", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SanitizeEmailForFilename(tt.email) + if result != tt.expected { + t.Errorf("SanitizeEmailForFilename() = %q, want %q", result, tt.expected) + } + }) + } +} + +// createTestJWT creates a test JWT token with the given claims +func createTestJWT(claims map[string]any) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`)) + + payloadBytes, _ := json.Marshal(claims) + payload := base64.RawURLEncoding.EncodeToString(payloadBytes) + + signature := base64.RawURLEncoding.EncodeToString([]byte("fake-signature")) + + return header + "." + payload + "." + signature +} diff --git a/internal/auth/kiro/oauth.go b/internal/auth/kiro/oauth.go new file mode 100644 index 00000000..e828da14 --- /dev/null +++ b/internal/auth/kiro/oauth.go @@ -0,0 +1,296 @@ +// Package kiro provides OAuth2 authentication for Kiro using native Google login. +package kiro + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "html" + "io" + "net" + "net/http" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // Kiro auth endpoint + kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev" + + // Default callback port + defaultCallbackPort = 9876 + + // Auth timeout + authTimeout = 10 * time.Minute +) + +// KiroTokenResponse represents the response from Kiro token endpoint. +type KiroTokenResponse struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ProfileArn string `json:"profileArn"` + ExpiresIn int `json:"expiresIn"` +} + +// KiroOAuth handles the OAuth flow for Kiro authentication. +type KiroOAuth struct { + httpClient *http.Client + cfg *config.Config +} + +// NewKiroOAuth creates a new Kiro OAuth handler. +func NewKiroOAuth(cfg *config.Config) *KiroOAuth { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &KiroOAuth{ + httpClient: client, + cfg: cfg, + } +} + +// generateCodeVerifier generates a random code verifier for PKCE. +func generateCodeVerifier() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// generateCodeChallenge generates the code challenge from verifier. +func generateCodeChallenge(verifier string) string { + h := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(h[:]) +} + +// generateState generates a random state parameter. +func generateState() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// AuthResult contains the authorization code and state from callback. +type AuthResult struct { + Code string + State string + Error string +} + +// startCallbackServer starts a local HTTP server to receive the OAuth callback. +func (o *KiroOAuth) startCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthResult, error) { + // Try to find an available port - use localhost like Kiro does + listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", defaultCallbackPort)) + if err != nil { + // Try with dynamic port (RFC 8252 allows dynamic ports for native apps) + log.Warnf("kiro oauth: default port %d is busy, falling back to dynamic port", defaultCallbackPort) + listener, err = net.Listen("tcp", "localhost:0") + if err != nil { + return "", nil, fmt.Errorf("failed to start callback server: %w", err) + } + } + + port := listener.Addr().(*net.TCPAddr).Port + // Use http scheme for local callback server + redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port) + resultChan := make(chan AuthResult, 1) + + server := &http.Server{ + ReadHeaderTimeout: 10 * time.Second, + } + + mux := http.NewServeMux() + mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errParam := r.URL.Query().Get("error") + + if errParam != "" { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, `

Login Failed

%s

You can close this window.

`, html.EscapeString(errParam)) + resultChan <- AuthResult{Error: errParam} + return + } + + if state != expectedState { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, `

Login Failed

Invalid state parameter

You can close this window.

`) + resultChan <- AuthResult{Error: "state mismatch"} + return + } + + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `

Login Successful!

You can close this window and return to the terminal.

`) + resultChan <- AuthResult{Code: code, State: state} + }) + + server.Handler = mux + + go func() { + if err := server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Debugf("callback server error: %v", err) + } + }() + + go func() { + select { + case <-ctx.Done(): + case <-time.After(authTimeout): + case <-resultChan: + } + _ = server.Shutdown(context.Background()) + }() + + return redirectURI, resultChan, nil +} + +// LoginWithBuilderID performs OAuth login with AWS Builder ID using device code flow. +func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) { + ssoClient := NewSSOOIDCClient(o.cfg) + return ssoClient.LoginWithBuilderID(ctx) +} + +// exchangeCodeForToken exchanges the authorization code for tokens. +func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) { + payload := map[string]string{ + "code": code, + "code_verifier": codeVerifier, + "redirect_uri": redirectURI, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + tokenURL := kiroAuthEndpoint + "/oauth/token" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("token request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode) + } + + var tokenResp KiroTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + }, nil +} + +// RefreshToken refreshes an expired access token. +func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) { + payload := map[string]string{ + "refreshToken": refreshToken, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + refreshURL := kiroAuthEndpoint + "/refreshToken" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("refresh request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var tokenResp KiroTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + }, nil +} + +// LoginWithGoogle performs OAuth login with Google using Kiro's social auth. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) { + socialClient := NewSocialAuthClient(o.cfg) + return socialClient.LoginWithGoogle(ctx) +} + +// LoginWithGitHub performs OAuth login with GitHub using Kiro's social auth. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (o *KiroOAuth) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) { + socialClient := NewSocialAuthClient(o.cfg) + return socialClient.LoginWithGitHub(ctx) +} diff --git a/internal/auth/kiro/protocol_handler.go b/internal/auth/kiro/protocol_handler.go new file mode 100644 index 00000000..d900ee33 --- /dev/null +++ b/internal/auth/kiro/protocol_handler.go @@ -0,0 +1,725 @@ +// Package kiro provides custom protocol handler registration for Kiro OAuth. +// This enables the CLI to intercept kiro:// URIs for social authentication (Google/GitHub). +package kiro + +import ( + "context" + "fmt" + "html" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // KiroProtocol is the custom URI scheme used by Kiro + KiroProtocol = "kiro" + + // KiroAuthority is the URI authority for authentication callbacks + KiroAuthority = "kiro.kiroAgent" + + // KiroAuthPath is the path for successful authentication + KiroAuthPath = "/authenticate-success" + + // KiroRedirectURI is the full redirect URI for social auth + KiroRedirectURI = "kiro://kiro.kiroAgent/authenticate-success" + + // DefaultHandlerPort is the default port for the local callback server + DefaultHandlerPort = 19876 + + // HandlerTimeout is how long to wait for the OAuth callback + HandlerTimeout = 10 * time.Minute +) + +// ProtocolHandler manages the custom kiro:// protocol handler for OAuth callbacks. +type ProtocolHandler struct { + port int + server *http.Server + listener net.Listener + resultChan chan *AuthCallback + stopChan chan struct{} + mu sync.Mutex + running bool +} + +// AuthCallback contains the OAuth callback parameters. +type AuthCallback struct { + Code string + State string + Error string +} + +// NewProtocolHandler creates a new protocol handler. +func NewProtocolHandler() *ProtocolHandler { + return &ProtocolHandler{ + port: DefaultHandlerPort, + resultChan: make(chan *AuthCallback, 1), + stopChan: make(chan struct{}), + } +} + +// Start starts the local callback server that receives redirects from the protocol handler. +func (h *ProtocolHandler) Start(ctx context.Context) (int, error) { + h.mu.Lock() + defer h.mu.Unlock() + + if h.running { + return h.port, nil + } + + // Drain any stale results from previous runs + select { + case <-h.resultChan: + default: + } + + // Reset stopChan for reuse - close old channel first to unblock any waiting goroutines + if h.stopChan != nil { + select { + case <-h.stopChan: + // Already closed + default: + close(h.stopChan) + } + } + h.stopChan = make(chan struct{}) + + // Try ports in known range (must match handler script port range) + var listener net.Listener + var err error + portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4} + + for _, port := range portRange { + listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err == nil { + break + } + log.Debugf("kiro protocol handler: port %d busy, trying next", port) + } + + if listener == nil { + return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4) + } + + h.listener = listener + h.port = listener.Addr().(*net.TCPAddr).Port + + mux := http.NewServeMux() + mux.HandleFunc("/oauth/callback", h.handleCallback) + + h.server = &http.Server{ + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + } + + go func() { + if err := h.server.Serve(listener); err != nil && err != http.ErrServerClosed { + log.Debugf("kiro protocol handler server error: %v", err) + } + }() + + h.running = true + log.Debugf("kiro protocol handler started on port %d", h.port) + + // Auto-shutdown after context done, timeout, or explicit stop + // Capture references to prevent race with new Start() calls + currentStopChan := h.stopChan + currentServer := h.server + currentListener := h.listener + go func() { + select { + case <-ctx.Done(): + case <-time.After(HandlerTimeout): + case <-currentStopChan: + return // Already stopped, exit goroutine + } + // Only stop if this is still the current server/listener instance + h.mu.Lock() + if h.server == currentServer && h.listener == currentListener { + h.mu.Unlock() + h.Stop() + } else { + h.mu.Unlock() + } + }() + + return h.port, nil +} + +// Stop stops the callback server. +func (h *ProtocolHandler) Stop() { + h.mu.Lock() + defer h.mu.Unlock() + + if !h.running { + return + } + + // Signal the auto-shutdown goroutine to exit. + // This select pattern is safe because stopChan is only modified while holding h.mu, + // and we hold the lock here. The select prevents panic from double-close. + select { + case <-h.stopChan: + // Already closed + default: + close(h.stopChan) + } + + if h.server != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = h.server.Shutdown(ctx) + } + + h.running = false + log.Debug("kiro protocol handler stopped") +} + +// WaitForCallback waits for the OAuth callback and returns the result. +func (h *ProtocolHandler) WaitForCallback(ctx context.Context) (*AuthCallback, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(HandlerTimeout): + return nil, fmt.Errorf("timeout waiting for OAuth callback") + case result := <-h.resultChan: + return result, nil + } +} + +// GetPort returns the port the handler is listening on. +func (h *ProtocolHandler) GetPort() int { + return h.port +} + +// handleCallback processes the OAuth callback from the protocol handler script. +func (h *ProtocolHandler) handleCallback(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + errParam := r.URL.Query().Get("error") + + result := &AuthCallback{ + Code: code, + State: state, + Error: errParam, + } + + // Send result + select { + case h.resultChan <- result: + default: + // Channel full, ignore duplicate callbacks + } + + // Send success response + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if errParam != "" { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, ` + +Login Failed + +

Login Failed

+

Error: %s

+

You can close this window.

+ +`, html.EscapeString(errParam)) + } else { + fmt.Fprint(w, ` + +Login Successful + +

Login Successful!

+

You can close this window and return to the terminal.

+ + +`) + } +} + +// IsProtocolHandlerInstalled checks if the kiro:// protocol handler is installed. +func IsProtocolHandlerInstalled() bool { + switch runtime.GOOS { + case "linux": + return isLinuxHandlerInstalled() + case "windows": + return isWindowsHandlerInstalled() + case "darwin": + return isDarwinHandlerInstalled() + default: + return false + } +} + +// InstallProtocolHandler installs the kiro:// protocol handler for the current platform. +func InstallProtocolHandler(handlerPort int) error { + switch runtime.GOOS { + case "linux": + return installLinuxHandler(handlerPort) + case "windows": + return installWindowsHandler(handlerPort) + case "darwin": + return installDarwinHandler(handlerPort) + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} + +// UninstallProtocolHandler removes the kiro:// protocol handler. +func UninstallProtocolHandler() error { + switch runtime.GOOS { + case "linux": + return uninstallLinuxHandler() + case "windows": + return uninstallWindowsHandler() + case "darwin": + return uninstallDarwinHandler() + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} + +// --- Linux Implementation --- + +func getLinuxDesktopPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, ".local", "share", "applications", "kiro-oauth-handler.desktop") +} + +func getLinuxHandlerScriptPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, ".local", "bin", "kiro-oauth-handler") +} + +func isLinuxHandlerInstalled() bool { + desktopPath := getLinuxDesktopPath() + _, err := os.Stat(desktopPath) + return err == nil +} + +func installLinuxHandler(handlerPort int) error { + // Create directories + homeDir, err := os.UserHomeDir() + if err != nil { + return err + } + + binDir := filepath.Join(homeDir, ".local", "bin") + appDir := filepath.Join(homeDir, ".local", "share", "applications") + + if err := os.MkdirAll(binDir, 0755); err != nil { + return fmt.Errorf("failed to create bin directory: %w", err) + } + if err := os.MkdirAll(appDir, 0755); err != nil { + return fmt.Errorf("failed to create applications directory: %w", err) + } + + // Create handler script - tries multiple ports to handle dynamic port allocation + scriptPath := getLinuxHandlerScriptPath() + scriptContent := fmt.Sprintf(`#!/bin/bash +# Kiro OAuth Protocol Handler +# Handles kiro:// URIs - tries CLI first, then forwards to Kiro IDE + +URL="$1" + +# Check curl availability +if ! command -v curl &> /dev/null; then + echo "Error: curl is required for Kiro OAuth handler" >&2 + exit 1 +fi + +# Extract code and state from URL +[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}" +[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}" +[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}" + +# Try CLI proxy on multiple possible ports (default + dynamic range) +CLI_OK=0 +for PORT in %d %d %d %d %d; do + if [ -n "$ERROR" ]; then + curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && CLI_OK=1 && break + elif [ -n "$CODE" ] && [ -n "$STATE" ]; then + curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && CLI_OK=1 && break + fi +done + +# If CLI not available, forward to Kiro IDE +if [ $CLI_OK -eq 0 ] && [ -x "/usr/share/kiro/kiro" ]; then + /usr/share/kiro/kiro --open-url "$URL" & +fi +`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4) + + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil { + return fmt.Errorf("failed to write handler script: %w", err) + } + + // Create .desktop file + desktopPath := getLinuxDesktopPath() + desktopContent := fmt.Sprintf(`[Desktop Entry] +Name=Kiro OAuth Handler +Comment=Handle kiro:// protocol for CLI Proxy API authentication +Exec=%s %%u +Type=Application +Terminal=false +NoDisplay=true +MimeType=x-scheme-handler/kiro; +Categories=Utility; +`, scriptPath) + + if err := os.WriteFile(desktopPath, []byte(desktopContent), 0644); err != nil { + return fmt.Errorf("failed to write desktop file: %w", err) + } + + // Register handler with xdg-mime + cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro") + if err := cmd.Run(); err != nil { + log.Warnf("xdg-mime registration failed (may need manual setup): %v", err) + } + + // Update desktop database + cmd = exec.Command("update-desktop-database", appDir) + _ = cmd.Run() // Ignore errors, not critical + + log.Info("Kiro protocol handler installed for Linux") + return nil +} + +func uninstallLinuxHandler() error { + desktopPath := getLinuxDesktopPath() + scriptPath := getLinuxHandlerScriptPath() + + if err := os.Remove(desktopPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove desktop file: %w", err) + } + if err := os.Remove(scriptPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove handler script: %w", err) + } + + log.Info("Kiro protocol handler uninstalled") + return nil +} + +// --- Windows Implementation --- + +func isWindowsHandlerInstalled() bool { + // Check registry key existence + cmd := exec.Command("reg", "query", `HKCU\Software\Classes\kiro`, "/ve") + return cmd.Run() == nil +} + +func installWindowsHandler(handlerPort int) error { + homeDir, err := os.UserHomeDir() + if err != nil { + return err + } + + // Create handler script (PowerShell) + scriptDir := filepath.Join(homeDir, ".cliproxyapi") + if err := os.MkdirAll(scriptDir, 0755); err != nil { + return fmt.Errorf("failed to create script directory: %w", err) + } + + scriptPath := filepath.Join(scriptDir, "kiro-oauth-handler.ps1") + scriptContent := fmt.Sprintf(`# Kiro OAuth Protocol Handler for Windows +param([string]$url) + +# Load required assembly for HttpUtility +Add-Type -AssemblyName System.Web + +# Parse URL parameters +$uri = [System.Uri]$url +$query = [System.Web.HttpUtility]::ParseQueryString($uri.Query) +$code = $query["code"] +$state = $query["state"] +$errorParam = $query["error"] + +# Try multiple ports (default + dynamic range) +$ports = @(%d, %d, %d, %d, %d) +$success = $false + +foreach ($port in $ports) { + if ($success) { break } + $callbackUrl = "http://127.0.0.1:$port/oauth/callback" + try { + if ($errorParam) { + $fullUrl = $callbackUrl + "?error=" + $errorParam + Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null + $success = $true + } elseif ($code -and $state) { + $fullUrl = $callbackUrl + "?code=" + $code + "&state=" + $state + Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null + $success = $true + } + } catch { + # Try next port + } +} +`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4) + + if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil { + return fmt.Errorf("failed to write handler script: %w", err) + } + + // Create batch wrapper + batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat") + batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" %%1\n", scriptPath) + + if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil { + return fmt.Errorf("failed to write batch wrapper: %w", err) + } + + // Register in Windows registry + commands := [][]string{ + {"reg", "add", `HKCU\Software\Classes\kiro`, "/ve", "/d", "URL:Kiro Protocol", "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro`, "/v", "URL Protocol", "/d", "", "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell`, "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell\open`, "/f"}, + {"reg", "add", `HKCU\Software\Classes\kiro\shell\open\command`, "/ve", "/d", fmt.Sprintf("\"%s\" \"%%1\"", batchPath), "/f"}, + } + + for _, args := range commands { + cmd := exec.Command(args[0], args[1:]...) + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to run registry command: %w", err) + } + } + + log.Info("Kiro protocol handler installed for Windows") + return nil +} + +func uninstallWindowsHandler() error { + // Remove registry keys + cmd := exec.Command("reg", "delete", `HKCU\Software\Classes\kiro`, "/f") + if err := cmd.Run(); err != nil { + log.Warnf("failed to remove registry key: %v", err) + } + + // Remove scripts + homeDir, _ := os.UserHomeDir() + scriptDir := filepath.Join(homeDir, ".cliproxyapi") + _ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.ps1")) + _ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.bat")) + + log.Info("Kiro protocol handler uninstalled") + return nil +} + +// --- macOS Implementation --- + +func getDarwinAppPath() string { + homeDir, _ := os.UserHomeDir() + return filepath.Join(homeDir, "Applications", "KiroOAuthHandler.app") +} + +func isDarwinHandlerInstalled() bool { + appPath := getDarwinAppPath() + _, err := os.Stat(appPath) + return err == nil +} + +func installDarwinHandler(handlerPort int) error { + // Create app bundle structure + appPath := getDarwinAppPath() + contentsPath := filepath.Join(appPath, "Contents") + macOSPath := filepath.Join(contentsPath, "MacOS") + + if err := os.MkdirAll(macOSPath, 0755); err != nil { + return fmt.Errorf("failed to create app bundle: %w", err) + } + + // Create Info.plist + plistPath := filepath.Join(contentsPath, "Info.plist") + plistContent := ` + + + + CFBundleIdentifier + com.cliproxyapi.kiro-oauth-handler + CFBundleName + KiroOAuthHandler + CFBundleExecutable + kiro-oauth-handler + CFBundleVersion + 1.0 + CFBundleURLTypes + + + CFBundleURLName + Kiro Protocol + CFBundleURLSchemes + + kiro + + + + LSBackgroundOnly + + +` + + if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil { + return fmt.Errorf("failed to write Info.plist: %w", err) + } + + // Create executable script - tries multiple ports to handle dynamic port allocation + execPath := filepath.Join(macOSPath, "kiro-oauth-handler") + execContent := fmt.Sprintf(`#!/bin/bash +# Kiro OAuth Protocol Handler for macOS + +URL="$1" + +# Check curl availability (should always exist on macOS) +if [ ! -x /usr/bin/curl ]; then + echo "Error: curl is required for Kiro OAuth handler" >&2 + exit 1 +fi + +# Extract code and state from URL +[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}" +[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}" +[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}" + +# Try multiple ports (default + dynamic range) +for PORT in %d %d %d %d %d; do + if [ -n "$ERROR" ]; then + /usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && exit 0 + elif [ -n "$CODE" ] && [ -n "$STATE" ]; then + /usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && exit 0 + fi +done +`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4) + + if err := os.WriteFile(execPath, []byte(execContent), 0755); err != nil { + return fmt.Errorf("failed to write executable: %w", err) + } + + // Register the app with Launch Services + cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister", + "-f", appPath) + if err := cmd.Run(); err != nil { + log.Warnf("lsregister failed (handler may still work): %v", err) + } + + log.Info("Kiro protocol handler installed for macOS") + return nil +} + +func uninstallDarwinHandler() error { + appPath := getDarwinAppPath() + + // Unregister from Launch Services + cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister", + "-u", appPath) + _ = cmd.Run() + + // Remove app bundle + if err := os.RemoveAll(appPath); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to remove app bundle: %w", err) + } + + log.Info("Kiro protocol handler uninstalled") + return nil +} + +// ParseKiroURI parses a kiro:// URI and extracts the callback parameters. +func ParseKiroURI(rawURI string) (*AuthCallback, error) { + u, err := url.Parse(rawURI) + if err != nil { + return nil, fmt.Errorf("invalid URI: %w", err) + } + + if u.Scheme != KiroProtocol { + return nil, fmt.Errorf("invalid scheme: expected %s, got %s", KiroProtocol, u.Scheme) + } + + if u.Host != KiroAuthority { + return nil, fmt.Errorf("invalid authority: expected %s, got %s", KiroAuthority, u.Host) + } + + query := u.Query() + return &AuthCallback{ + Code: query.Get("code"), + State: query.Get("state"), + Error: query.Get("error"), + }, nil +} + +// GetHandlerInstructions returns platform-specific instructions for manual handler setup. +func GetHandlerInstructions() string { + switch runtime.GOOS { + case "linux": + return `To manually set up the Kiro protocol handler on Linux: + +1. Create ~/.local/share/applications/kiro-oauth-handler.desktop: + [Desktop Entry] + Name=Kiro OAuth Handler + Exec=~/.local/bin/kiro-oauth-handler %u + Type=Application + Terminal=false + MimeType=x-scheme-handler/kiro; + +2. Create ~/.local/bin/kiro-oauth-handler (make it executable): + #!/bin/bash + URL="$1" + # ... (see generated script for full content) + +3. Run: xdg-mime default kiro-oauth-handler.desktop x-scheme-handler/kiro` + + case "windows": + return `To manually set up the Kiro protocol handler on Windows: + +1. Open Registry Editor (regedit.exe) +2. Create key: HKEY_CURRENT_USER\Software\Classes\kiro +3. Set default value to: URL:Kiro Protocol +4. Create string value "URL Protocol" with empty data +5. Create subkey: shell\open\command +6. Set default value to: "C:\path\to\handler.bat" "%1"` + + case "darwin": + return `To manually set up the Kiro protocol handler on macOS: + +1. Create ~/Applications/KiroOAuthHandler.app bundle +2. Add Info.plist with CFBundleURLTypes containing "kiro" scheme +3. Create executable in Contents/MacOS/ +4. Run: /System/Library/.../lsregister -f ~/Applications/KiroOAuthHandler.app` + + default: + return "Protocol handler setup is not supported on this platform." + } +} + +// SetupProtocolHandlerIfNeeded checks and installs the protocol handler if needed. +func SetupProtocolHandlerIfNeeded(handlerPort int) error { + if IsProtocolHandlerInstalled() { + log.Debug("Kiro protocol handler already installed") + return nil + } + + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Protocol Handler Setup Required ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + fmt.Println("\nTo enable Google/GitHub login, we need to install a protocol handler.") + fmt.Println("This allows your browser to redirect back to the CLI after authentication.") + fmt.Println("\nInstalling protocol handler...") + + if err := InstallProtocolHandler(handlerPort); err != nil { + fmt.Printf("\n⚠ Automatic installation failed: %v\n", err) + fmt.Println("\nManual setup instructions:") + fmt.Println(strings.Repeat("-", 60)) + fmt.Println(GetHandlerInstructions()) + return err + } + + fmt.Println("\n✓ Protocol handler installed successfully!") + return nil +} diff --git a/internal/auth/kiro/social_auth.go b/internal/auth/kiro/social_auth.go new file mode 100644 index 00000000..2ac29bf8 --- /dev/null +++ b/internal/auth/kiro/social_auth.go @@ -0,0 +1,403 @@ +// Package kiro provides social authentication (Google/GitHub) for Kiro via AuthServiceClient. +package kiro + +import ( + "bufio" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "os/exec" + "runtime" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" + "golang.org/x/term" +) + +const ( + // Kiro AuthService endpoint + kiroAuthServiceEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev" + + // OAuth timeout + socialAuthTimeout = 10 * time.Minute +) + +// SocialProvider represents the social login provider. +type SocialProvider string + +const ( + // ProviderGoogle is Google OAuth provider + ProviderGoogle SocialProvider = "Google" + // ProviderGitHub is GitHub OAuth provider + ProviderGitHub SocialProvider = "Github" + // Note: AWS Builder ID is NOT supported by Kiro's auth service. + // It only supports: Google, Github, Cognito + // AWS Builder ID must use device code flow via SSO OIDC. +) + +// CreateTokenRequest is sent to Kiro's /oauth/token endpoint. +type CreateTokenRequest struct { + Code string `json:"code"` + CodeVerifier string `json:"code_verifier"` + RedirectURI string `json:"redirect_uri"` + InvitationCode string `json:"invitation_code,omitempty"` +} + +// SocialTokenResponse from Kiro's /oauth/token endpoint for social auth. +type SocialTokenResponse struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ProfileArn string `json:"profileArn"` + ExpiresIn int `json:"expiresIn"` +} + +// RefreshTokenRequest is sent to Kiro's /refreshToken endpoint. +type RefreshTokenRequest struct { + RefreshToken string `json:"refreshToken"` +} + +// SocialAuthClient handles social authentication with Kiro. +type SocialAuthClient struct { + httpClient *http.Client + cfg *config.Config + protocolHandler *ProtocolHandler +} + +// NewSocialAuthClient creates a new social auth client. +func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &SocialAuthClient{ + httpClient: client, + cfg: cfg, + protocolHandler: NewProtocolHandler(), + } +} + +// generatePKCE generates PKCE code verifier and challenge. +func generatePKCE() (verifier, challenge string, err error) { + // Generate 32 bytes of random data for verifier + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", "", fmt.Errorf("failed to generate random bytes: %w", err) + } + verifier = base64.RawURLEncoding.EncodeToString(b) + + // Generate SHA256 hash of verifier for challenge + h := sha256.Sum256([]byte(verifier)) + challenge = base64.RawURLEncoding.EncodeToString(h[:]) + + return verifier, challenge, nil +} + +// generateState generates a random state parameter. +func generateStateParam() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// buildLoginURL constructs the Kiro OAuth login URL. +// The login endpoint expects a GET request with query parameters. +// Format: /login?idp=Google&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...&prompt=select_account +// The prompt=select_account parameter forces the account selection screen even if already logged in. +func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, state string) string { + return fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account", + kiroAuthServiceEndpoint, + provider, + url.QueryEscape(redirectURI), + codeChallenge, + state, + ) +} + +// CreateToken exchanges the authorization code for tokens. +func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal token request: %w", err) + } + + tokenURL := kiroAuthServiceEndpoint + "/oauth/token" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("token request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read token response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode) + } + + var tokenResp SocialTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + return &tokenResp, nil +} + +// RefreshSocialToken refreshes an expired social auth token. +func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) { + body, err := json.Marshal(&RefreshTokenRequest{RefreshToken: refreshToken}) + if err != nil { + return nil, fmt.Errorf("failed to marshal refresh request: %w", err) + } + + refreshURL := kiroAuthServiceEndpoint + "/refreshToken" + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body))) + if err != nil { + return nil, fmt.Errorf("failed to create refresh request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("refresh request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read refresh response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var tokenResp SocialTokenResponse + if err := json.Unmarshal(respBody, &tokenResp); err != nil { + return nil, fmt.Errorf("failed to parse refresh response: %w", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 // Default 1 hour + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: "", // Caller should preserve original provider + }, nil +} + +// LoginWithSocial performs OAuth login with Google. +func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) { + providerName := string(provider) + + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName) + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Step 1: Setup protocol handler + fmt.Println("\nSetting up authentication...") + + // Start the local callback server + handlerPort, err := c.protocolHandler.Start(ctx) + if err != nil { + return nil, fmt.Errorf("failed to start callback server: %w", err) + } + defer c.protocolHandler.Stop() + + // Ensure protocol handler is installed and set as default + if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil { + fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...") + fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.") + fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol") + log.Debugf("kiro: protocol handler setup error: %v", err) + // Continue anyway - user might have set it up manually or select browser manually + } else { + // Force set our handler as default (prevents "Open with" dialog) + forceDefaultProtocolHandler() + } + + // Step 2: Generate PKCE codes + codeVerifier, codeChallenge, err := generatePKCE() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE: %w", err) + } + + // Step 3: Generate state + state, err := generateStateParam() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + + // Step 4: Build the login URL (Kiro uses GET request with query params) + authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state) + + // Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito) + // Incognito mode enables multi-account support by bypassing cached sessions + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + if !c.cfg.IncognitoBrowser { + log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.") + } else { + log.Debug("kiro: using incognito mode for multi-account support") + } + } else { + browser.SetIncognitoMode(true) // Default to incognito if no config + log.Debug("kiro: using incognito mode for multi-account support (default)") + } + + // Step 5: Open browser for user authentication + fmt.Println("\n════════════════════════════════════════════════════════════") + fmt.Printf(" Opening browser for %s authentication...\n", providerName) + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n URL: %s\n\n", authURL) + + if err := browser.OpenURL(authURL); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" ⚠ Could not open browser automatically.") + fmt.Println(" Please open the URL above in your browser manually.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + fmt.Println("\n Waiting for authentication callback...") + + // Step 6: Wait for callback + callback, err := c.protocolHandler.WaitForCallback(ctx) + if err != nil { + return nil, fmt.Errorf("failed to receive callback: %w", err) + } + + if callback.Error != "" { + return nil, fmt.Errorf("authentication error: %s", callback.Error) + } + + if callback.State != state { + // Log state values for debugging, but don't expose in user-facing error + log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State) + return nil, fmt.Errorf("OAuth state validation failed - please try again") + } + + if callback.Code == "" { + return nil, fmt.Errorf("no authorization code received") + } + + fmt.Println("\n✓ Authorization received!") + + // Step 7: Exchange code for tokens + fmt.Println("Exchanging code for tokens...") + + tokenReq := &CreateTokenRequest{ + Code: callback.Code, + CodeVerifier: codeVerifier, + RedirectURI: KiroRedirectURI, + } + + tokenResp, err := c.CreateToken(ctx, tokenReq) + if err != nil { + return nil, fmt.Errorf("failed to exchange code for tokens: %w", err) + } + + fmt.Println("\n✓ Authentication successful!") + + // Close the browser window + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + // Validate ExpiresIn - use default 1 hour if invalid + expiresIn := tokenResp.ExpiresIn + if expiresIn <= 0 { + expiresIn = 3600 + } + expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second) + + // Try to extract email from JWT access token first + email := ExtractEmailFromJWT(tokenResp.AccessToken) + + // If no email in JWT, ask user for account label (only in interactive mode) + if email == "" && isInteractiveTerminal() { + fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ") + reader := bufio.NewReader(os.Stdin) + var err error + email, err = reader.ReadString('\n') + if err != nil { + log.Debugf("Failed to read account label: %v", err) + } + email = strings.TrimSpace(email) + } + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: tokenResp.ProfileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "social", + Provider: providerName, + Email: email, // JWT email or user-provided label + }, nil +} + +// LoginWithGoogle performs OAuth login with Google. +func (c *SocialAuthClient) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) { + return c.LoginWithSocial(ctx, ProviderGoogle) +} + +// LoginWithGitHub performs OAuth login with GitHub. +func (c *SocialAuthClient) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) { + return c.LoginWithSocial(ctx, ProviderGitHub) +} + +// forceDefaultProtocolHandler sets our protocol handler as the default for kiro:// URLs. +// This prevents the "Open with" dialog from appearing on Linux. +// On non-Linux platforms, this is a no-op as they use different mechanisms. +func forceDefaultProtocolHandler() { + if runtime.GOOS != "linux" { + return // Non-Linux platforms use different handler mechanisms + } + + // Set our handler as default using xdg-mime + cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro") + if err := cmd.Run(); err != nil { + log.Warnf("Failed to set default protocol handler: %v. You may see a handler selection dialog.", err) + } +} + +// isInteractiveTerminal checks if stdin is connected to an interactive terminal. +// Returns false in CI/automated environments or when stdin is piped. +func isInteractiveTerminal() bool { + return term.IsTerminal(int(os.Stdin.Fd())) +} diff --git a/internal/auth/kiro/sso_oidc.go b/internal/auth/kiro/sso_oidc.go new file mode 100644 index 00000000..d3c27d16 --- /dev/null +++ b/internal/auth/kiro/sso_oidc.go @@ -0,0 +1,527 @@ +// Package kiro provides AWS SSO OIDC authentication for Kiro. +package kiro + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // AWS SSO OIDC endpoints + ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com" + + // Kiro's start URL for Builder ID + builderIDStartURL = "https://view.awsapps.com/start" + + // Polling interval + pollInterval = 5 * time.Second +) + +// SSOOIDCClient handles AWS SSO OIDC authentication. +type SSOOIDCClient struct { + httpClient *http.Client + cfg *config.Config +} + +// NewSSOOIDCClient creates a new SSO OIDC client. +func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient { + client := &http.Client{Timeout: 30 * time.Second} + if cfg != nil { + client = util.SetProxy(&cfg.SDKConfig, client) + } + return &SSOOIDCClient{ + httpClient: client, + cfg: cfg, + } +} + +// RegisterClientResponse from AWS SSO OIDC. +type RegisterClientResponse struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + ClientIDIssuedAt int64 `json:"clientIdIssuedAt"` + ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"` +} + +// StartDeviceAuthResponse from AWS SSO OIDC. +type StartDeviceAuthResponse struct { + DeviceCode string `json:"deviceCode"` + UserCode string `json:"userCode"` + VerificationURI string `json:"verificationUri"` + VerificationURIComplete string `json:"verificationUriComplete"` + ExpiresIn int `json:"expiresIn"` + Interval int `json:"interval"` +} + +// CreateTokenResponse from AWS SSO OIDC. +type CreateTokenResponse struct { + AccessToken string `json:"accessToken"` + TokenType string `json:"tokenType"` + ExpiresIn int `json:"expiresIn"` + RefreshToken string `json:"refreshToken"` +} + +// RegisterClient registers a new OIDC client with AWS. +func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) { + // Generate unique client name for each registration to support multiple accounts + clientName := fmt.Sprintf("CLI-Proxy-API-%d", time.Now().UnixNano()) + + payload := map[string]interface{}{ + "clientName": clientName, + "clientType": "public", + "scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations"}, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode) + } + + var result RegisterClientResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// StartDeviceAuthorization starts the device authorization flow. +func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID, clientSecret string) (*StartDeviceAuthResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "startUrl": builderIDStartURL, + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/device_authorization", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode) + } + + var result StartDeviceAuthResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// CreateToken polls for the access token after user authorization. +func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, deviceCode string) (*CreateTokenResponse, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "deviceCode": deviceCode, + "grantType": "urn:ietf:params:oauth:grant-type:device_code", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Check for pending authorization + if resp.StatusCode == http.StatusBadRequest { + var errResp struct { + Error string `json:"error"` + } + if json.Unmarshal(respBody, &errResp) == nil { + if errResp.Error == "authorization_pending" { + return nil, fmt.Errorf("authorization_pending") + } + if errResp.Error == "slow_down" { + return nil, fmt.Errorf("slow_down") + } + } + log.Debugf("create token failed: %s", string(respBody)) + return nil, fmt.Errorf("create token failed") + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + return &result, nil +} + +// RefreshToken refreshes an access token using the refresh token. +func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) { + payload := map[string]string{ + "clientId": clientID, + "clientSecret": clientSecret, + "refreshToken": refreshToken, + "grantType": "refresh_token", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body))) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody)) + return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode) + } + + var result CreateTokenResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, err + } + + expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: clientID, + ClientSecret: clientSecret, + }, nil +} + +// LoginWithBuilderID performs the full device code flow for AWS Builder ID. +func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) { + fmt.Println("\n╔══════════════════════════════════════════════════════════╗") + fmt.Println("║ Kiro Authentication (AWS Builder ID) ║") + fmt.Println("╚══════════════════════════════════════════════════════════╝") + + // Step 1: Register client + fmt.Println("\nRegistering client...") + regResp, err := c.RegisterClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to register client: %w", err) + } + log.Debugf("Client registered: %s", regResp.ClientID) + + // Step 2: Start device authorization + fmt.Println("Starting device authorization...") + authResp, err := c.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret) + if err != nil { + return nil, fmt.Errorf("failed to start device auth: %w", err) + } + + // Step 3: Show user the verification URL + fmt.Printf("\n") + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf(" Open this URL in your browser:\n") + fmt.Printf(" %s\n", authResp.VerificationURIComplete) + fmt.Println("════════════════════════════════════════════════════════════") + fmt.Printf("\n Or go to: %s\n", authResp.VerificationURI) + fmt.Printf(" And enter code: %s\n\n", authResp.UserCode) + + // Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito) + // Incognito mode enables multi-account support by bypassing cached sessions + if c.cfg != nil { + browser.SetIncognitoMode(c.cfg.IncognitoBrowser) + if !c.cfg.IncognitoBrowser { + log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.") + } else { + log.Debug("kiro: using incognito mode for multi-account support") + } + } else { + browser.SetIncognitoMode(true) // Default to incognito if no config + log.Debug("kiro: using incognito mode for multi-account support (default)") + } + + // Open browser using cross-platform browser package + if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil { + log.Warnf("Could not open browser automatically: %v", err) + fmt.Println(" Please open the URL manually in your browser.") + } else { + fmt.Println(" (Browser opened automatically)") + } + + // Step 4: Poll for token + fmt.Println("Waiting for authorization...") + + interval := pollInterval + if authResp.Interval > 0 { + interval = time.Duration(authResp.Interval) * time.Second + } + + deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second) + + for time.Now().Before(deadline) { + select { + case <-ctx.Done(): + browser.CloseBrowser() // Cleanup on cancel + return nil, ctx.Err() + case <-time.After(interval): + tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode) + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "authorization_pending") { + fmt.Print(".") + continue + } + if strings.Contains(errStr, "slow_down") { + interval += 5 * time.Second + continue + } + // Close browser on error before returning + browser.CloseBrowser() + return nil, fmt.Errorf("token creation failed: %w", err) + } + + fmt.Println("\n\n✓ Authorization successful!") + + // Close the browser window + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser: %v", err) + } + + // Step 5: Get profile ARN from CodeWhisperer API + fmt.Println("Fetching profile information...") + profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken) + + // Extract email from JWT access token + email := ExtractEmailFromJWT(tokenResp.AccessToken) + if email != "" { + fmt.Printf(" Logged in as: %s\n", email) + } + + expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + return &KiroTokenData{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ProfileArn: profileArn, + ExpiresAt: expiresAt.Format(time.RFC3339), + AuthMethod: "builder-id", + Provider: "AWS", + ClientID: regResp.ClientID, + ClientSecret: regResp.ClientSecret, + Email: email, + }, nil + } + } + + // Close browser on timeout for better UX + if err := browser.CloseBrowser(); err != nil { + log.Debugf("Failed to close browser on timeout: %v", err) + } + return nil, fmt.Errorf("authorization timed out") +} + +// fetchProfileArn retrieves the profile ARN from CodeWhisperer API. +// This is needed for file naming since AWS SSO OIDC doesn't return profile info. +func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string { + // Try ListProfiles API first + profileArn := c.tryListProfiles(ctx, accessToken) + if profileArn != "" { + return profileArn + } + + // Fallback: Try ListAvailableCustomizations + return c.tryListCustomizations(ctx, accessToken) +} + +func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + } + + body, err := json.Marshal(payload) + if err != nil { + return "" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body))) + if err != nil { + return "" + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListProfiles") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "" + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody)) + return "" + } + + log.Debugf("ListProfiles response: %s", string(respBody)) + + var result struct { + Profiles []struct { + Arn string `json:"arn"` + } `json:"profiles"` + ProfileArn string `json:"profileArn"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if result.ProfileArn != "" { + return result.ProfileArn + } + + if len(result.Profiles) > 0 { + return result.Profiles[0].Arn + } + + return "" +} + +func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string { + payload := map[string]interface{}{ + "origin": "AI_EDITOR", + } + + body, err := json.Marshal(payload) + if err != nil { + return "" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body))) + if err != nil { + return "" + } + + req.Header.Set("Content-Type", "application/x-amz-json-1.0") + req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations") + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "" + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody)) + return "" + } + + log.Debugf("ListAvailableCustomizations response: %s", string(respBody)) + + var result struct { + Customizations []struct { + Arn string `json:"arn"` + } `json:"customizations"` + ProfileArn string `json:"profileArn"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if result.ProfileArn != "" { + return result.ProfileArn + } + + if len(result.Customizations) > 0 { + return result.Customizations[0].Arn + } + + return "" +} diff --git a/internal/auth/kiro/token.go b/internal/auth/kiro/token.go new file mode 100644 index 00000000..e83b1728 --- /dev/null +++ b/internal/auth/kiro/token.go @@ -0,0 +1,72 @@ +package kiro + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" +) + +// KiroTokenStorage holds the persistent token data for Kiro authentication. +type KiroTokenStorage struct { + // AccessToken is the OAuth2 access token for API access + AccessToken string `json:"access_token"` + // RefreshToken is used to obtain new access tokens + RefreshToken string `json:"refresh_token"` + // ProfileArn is the AWS CodeWhisperer profile ARN + ProfileArn string `json:"profile_arn"` + // ExpiresAt is the timestamp when the token expires + ExpiresAt string `json:"expires_at"` + // AuthMethod indicates the authentication method used + AuthMethod string `json:"auth_method"` + // Provider indicates the OAuth provider + Provider string `json:"provider"` + // LastRefresh is the timestamp of the last token refresh + LastRefresh string `json:"last_refresh"` +} + +// SaveTokenToFile persists the token storage to the specified file path. +func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error { + dir := filepath.Dir(authFilePath) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + data, err := json.MarshalIndent(s, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal token storage: %w", err) + } + + if err := os.WriteFile(authFilePath, data, 0600); err != nil { + return fmt.Errorf("failed to write token file: %w", err) + } + + return nil +} + +// LoadFromFile loads token storage from the specified file path. +func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) { + data, err := os.ReadFile(authFilePath) + if err != nil { + return nil, fmt.Errorf("failed to read token file: %w", err) + } + + var storage KiroTokenStorage + if err := json.Unmarshal(data, &storage); err != nil { + return nil, fmt.Errorf("failed to parse token file: %w", err) + } + + return &storage, nil +} + +// ToTokenData converts storage to KiroTokenData for API use. +func (s *KiroTokenStorage) ToTokenData() *KiroTokenData { + return &KiroTokenData{ + AccessToken: s.AccessToken, + RefreshToken: s.RefreshToken, + ProfileArn: s.ProfileArn, + ExpiresAt: s.ExpiresAt, + AuthMethod: s.AuthMethod, + Provider: s.Provider, + } +} diff --git a/internal/browser/browser.go b/internal/browser/browser.go index b24dc5e1..3a5aeea7 100644 --- a/internal/browser/browser.go +++ b/internal/browser/browser.go @@ -6,14 +6,49 @@ import ( "fmt" "os/exec" "runtime" + "strings" + "sync" + pkgbrowser "github.com/pkg/browser" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" ) +// incognitoMode controls whether to open URLs in incognito/private mode. +// This is useful for OAuth flows where you want to use a different account. +var incognitoMode bool + +// lastBrowserProcess stores the last opened browser process for cleanup +var lastBrowserProcess *exec.Cmd +var browserMutex sync.Mutex + +// SetIncognitoMode enables or disables incognito/private browsing mode. +func SetIncognitoMode(enabled bool) { + incognitoMode = enabled +} + +// IsIncognitoMode returns whether incognito mode is enabled. +func IsIncognitoMode() bool { + return incognitoMode +} + +// CloseBrowser closes the last opened browser process. +func CloseBrowser() error { + browserMutex.Lock() + defer browserMutex.Unlock() + + if lastBrowserProcess == nil || lastBrowserProcess.Process == nil { + return nil + } + + err := lastBrowserProcess.Process.Kill() + lastBrowserProcess = nil + return err +} + // OpenURL opens the specified URL in the default web browser. -// It first attempts to use a platform-agnostic library and falls back to -// platform-specific commands if that fails. +// It uses the pkg/browser library which provides robust cross-platform support +// for Windows, macOS, and Linux. +// If incognito mode is enabled, it will open in a private/incognito window. // // Parameters: // - url: The URL to open. @@ -21,16 +56,22 @@ import ( // Returns: // - An error if the URL cannot be opened, otherwise nil. func OpenURL(url string) error { - fmt.Printf("Attempting to open URL in browser: %s\n", url) + log.Debugf("Opening URL in browser: %s (incognito=%v)", url, incognitoMode) - // Try using the open-golang library first - err := open.Run(url) + // If incognito mode is enabled, use platform-specific incognito commands + if incognitoMode { + log.Debug("Using incognito mode") + return openURLIncognito(url) + } + + // Use pkg/browser for cross-platform support + err := pkgbrowser.OpenURL(url) if err == nil { - log.Debug("Successfully opened URL using open-golang library") + log.Debug("Successfully opened URL using pkg/browser library") return nil } - log.Debugf("open-golang failed: %v, trying platform-specific commands", err) + log.Debugf("pkg/browser failed: %v, trying platform-specific commands", err) // Fallback to platform-specific commands return openURLPlatformSpecific(url) @@ -78,18 +119,379 @@ func openURLPlatformSpecific(url string) error { return nil } +// openURLIncognito opens a URL in incognito/private browsing mode. +// It first tries to detect the default browser and use its incognito flag. +// Falls back to a chain of known browsers if detection fails. +// +// Parameters: +// - url: The URL to open. +// +// Returns: +// - An error if the URL cannot be opened, otherwise nil. +func openURLIncognito(url string) error { + // First, try to detect and use the default browser + if cmd := tryDefaultBrowserIncognito(url); cmd != nil { + log.Debugf("Using detected default browser: %s %v", cmd.Path, cmd.Args[1:]) + if err := cmd.Start(); err == nil { + storeBrowserProcess(cmd) + log.Debug("Successfully opened URL in default browser's incognito mode") + return nil + } + log.Debugf("Failed to start default browser, trying fallback chain") + } + + // Fallback to known browser chain + cmd := tryFallbackBrowsersIncognito(url) + if cmd == nil { + log.Warn("No browser with incognito support found, falling back to normal mode") + return openURLPlatformSpecific(url) + } + + log.Debugf("Running incognito command: %s %v", cmd.Path, cmd.Args[1:]) + err := cmd.Start() + if err != nil { + log.Warnf("Failed to open incognito browser: %v, falling back to normal mode", err) + return openURLPlatformSpecific(url) + } + + storeBrowserProcess(cmd) + log.Debug("Successfully opened URL in incognito/private mode") + return nil +} + +// storeBrowserProcess safely stores the browser process for later cleanup. +func storeBrowserProcess(cmd *exec.Cmd) { + browserMutex.Lock() + lastBrowserProcess = cmd + browserMutex.Unlock() +} + +// tryDefaultBrowserIncognito attempts to detect the default browser and return +// an exec.Cmd configured with the appropriate incognito flag. +func tryDefaultBrowserIncognito(url string) *exec.Cmd { + switch runtime.GOOS { + case "darwin": + return tryDefaultBrowserMacOS(url) + case "windows": + return tryDefaultBrowserWindows(url) + case "linux": + return tryDefaultBrowserLinux(url) + } + return nil +} + +// tryDefaultBrowserMacOS detects the default browser on macOS. +func tryDefaultBrowserMacOS(url string) *exec.Cmd { + // Try to get default browser from Launch Services + out, err := exec.Command("defaults", "read", "com.apple.LaunchServices/com.apple.launchservices.secure", "LSHandlers").Output() + if err != nil { + return nil + } + + output := string(out) + var browserName string + + // Parse the output to find the http/https handler + if containsBrowserID(output, "com.google.chrome") { + browserName = "chrome" + } else if containsBrowserID(output, "org.mozilla.firefox") { + browserName = "firefox" + } else if containsBrowserID(output, "com.apple.safari") { + browserName = "safari" + } else if containsBrowserID(output, "com.brave.browser") { + browserName = "brave" + } else if containsBrowserID(output, "com.microsoft.edgemac") { + browserName = "edge" + } + + return createMacOSIncognitoCmd(browserName, url) +} + +// containsBrowserID checks if the LaunchServices output contains a browser ID. +func containsBrowserID(output, bundleID string) bool { + return strings.Contains(output, bundleID) +} + +// createMacOSIncognitoCmd creates the appropriate incognito command for macOS browsers. +func createMacOSIncognitoCmd(browserName, url string) *exec.Cmd { + switch browserName { + case "chrome": + // Try direct path first + chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" + if _, err := exec.LookPath(chromePath); err == nil { + return exec.Command(chromePath, "--incognito", url) + } + return exec.Command("open", "-na", "Google Chrome", "--args", "--incognito", url) + case "firefox": + return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url) + case "safari": + // Safari doesn't have CLI incognito, try AppleScript + return tryAppleScriptSafariPrivate(url) + case "brave": + return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url) + case "edge": + return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url) + } + return nil +} + +// tryAppleScriptSafariPrivate attempts to open Safari in private browsing mode using AppleScript. +func tryAppleScriptSafariPrivate(url string) *exec.Cmd { + // AppleScript to open a new private window in Safari + script := fmt.Sprintf(` + tell application "Safari" + activate + tell application "System Events" + keystroke "n" using {command down, shift down} + delay 0.5 + end tell + set URL of document 1 to "%s" + end tell + `, url) + + cmd := exec.Command("osascript", "-e", script) + // Test if this approach works by checking if Safari is available + if _, err := exec.LookPath("/Applications/Safari.app/Contents/MacOS/Safari"); err != nil { + log.Debug("Safari not found, AppleScript private window not available") + return nil + } + log.Debug("Attempting Safari private window via AppleScript") + return cmd +} + +// tryDefaultBrowserWindows detects the default browser on Windows via registry. +func tryDefaultBrowserWindows(url string) *exec.Cmd { + // Query registry for default browser + out, err := exec.Command("reg", "query", + `HKEY_CURRENT_USER\Software\Microsoft\Windows\Shell\Associations\UrlAssociations\http\UserChoice`, + "/v", "ProgId").Output() + if err != nil { + return nil + } + + output := string(out) + var browserName string + + // Map ProgId to browser name + if strings.Contains(output, "ChromeHTML") { + browserName = "chrome" + } else if strings.Contains(output, "FirefoxURL") { + browserName = "firefox" + } else if strings.Contains(output, "MSEdgeHTM") { + browserName = "edge" + } else if strings.Contains(output, "BraveHTML") { + browserName = "brave" + } + + return createWindowsIncognitoCmd(browserName, url) +} + +// createWindowsIncognitoCmd creates the appropriate incognito command for Windows browsers. +func createWindowsIncognitoCmd(browserName, url string) *exec.Cmd { + switch browserName { + case "chrome": + paths := []string{ + "chrome", + `C:\Program Files\Google\Chrome\Application\chrome.exe`, + `C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`, + } + for _, p := range paths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--incognito", url) + } + } + case "firefox": + if path, err := exec.LookPath("firefox"); err == nil { + return exec.Command(path, "--private-window", url) + } + case "edge": + paths := []string{ + "msedge", + `C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`, + `C:\Program Files\Microsoft\Edge\Application\msedge.exe`, + } + for _, p := range paths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--inprivate", url) + } + } + case "brave": + paths := []string{ + `C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe`, + `C:\Program Files (x86)\BraveSoftware\Brave-Browser\Application\brave.exe`, + } + for _, p := range paths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--incognito", url) + } + } + } + return nil +} + +// tryDefaultBrowserLinux detects the default browser on Linux using xdg-settings. +func tryDefaultBrowserLinux(url string) *exec.Cmd { + out, err := exec.Command("xdg-settings", "get", "default-web-browser").Output() + if err != nil { + return nil + } + + desktop := string(out) + var browserName string + + // Map .desktop file to browser name + if strings.Contains(desktop, "google-chrome") || strings.Contains(desktop, "chrome") { + browserName = "chrome" + } else if strings.Contains(desktop, "firefox") { + browserName = "firefox" + } else if strings.Contains(desktop, "chromium") { + browserName = "chromium" + } else if strings.Contains(desktop, "brave") { + browserName = "brave" + } else if strings.Contains(desktop, "microsoft-edge") || strings.Contains(desktop, "msedge") { + browserName = "edge" + } + + return createLinuxIncognitoCmd(browserName, url) +} + +// createLinuxIncognitoCmd creates the appropriate incognito command for Linux browsers. +func createLinuxIncognitoCmd(browserName, url string) *exec.Cmd { + switch browserName { + case "chrome": + paths := []string{"google-chrome", "google-chrome-stable"} + for _, p := range paths { + if path, err := exec.LookPath(p); err == nil { + return exec.Command(path, "--incognito", url) + } + } + case "firefox": + paths := []string{"firefox", "firefox-esr"} + for _, p := range paths { + if path, err := exec.LookPath(p); err == nil { + return exec.Command(path, "--private-window", url) + } + } + case "chromium": + paths := []string{"chromium", "chromium-browser"} + for _, p := range paths { + if path, err := exec.LookPath(p); err == nil { + return exec.Command(path, "--incognito", url) + } + } + case "brave": + if path, err := exec.LookPath("brave-browser"); err == nil { + return exec.Command(path, "--incognito", url) + } + case "edge": + if path, err := exec.LookPath("microsoft-edge"); err == nil { + return exec.Command(path, "--inprivate", url) + } + } + return nil +} + +// tryFallbackBrowsersIncognito tries a chain of known browsers as fallback. +func tryFallbackBrowsersIncognito(url string) *exec.Cmd { + switch runtime.GOOS { + case "darwin": + return tryFallbackBrowsersMacOS(url) + case "windows": + return tryFallbackBrowsersWindows(url) + case "linux": + return tryFallbackBrowsersLinuxChain(url) + } + return nil +} + +// tryFallbackBrowsersMacOS tries known browsers on macOS. +func tryFallbackBrowsersMacOS(url string) *exec.Cmd { + // Try Chrome + chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" + if _, err := exec.LookPath(chromePath); err == nil { + return exec.Command(chromePath, "--incognito", url) + } + // Try Firefox + if _, err := exec.LookPath("/Applications/Firefox.app/Contents/MacOS/firefox"); err == nil { + return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url) + } + // Try Brave + if _, err := exec.LookPath("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"); err == nil { + return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url) + } + // Try Edge + if _, err := exec.LookPath("/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge"); err == nil { + return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url) + } + // Last resort: try Safari with AppleScript + if cmd := tryAppleScriptSafariPrivate(url); cmd != nil { + log.Info("Using Safari with AppleScript for private browsing (may require accessibility permissions)") + return cmd + } + return nil +} + +// tryFallbackBrowsersWindows tries known browsers on Windows. +func tryFallbackBrowsersWindows(url string) *exec.Cmd { + // Chrome + chromePaths := []string{ + "chrome", + `C:\Program Files\Google\Chrome\Application\chrome.exe`, + `C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`, + } + for _, p := range chromePaths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--incognito", url) + } + } + // Firefox + if path, err := exec.LookPath("firefox"); err == nil { + return exec.Command(path, "--private-window", url) + } + // Edge (usually available on Windows 10+) + edgePaths := []string{ + "msedge", + `C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`, + `C:\Program Files\Microsoft\Edge\Application\msedge.exe`, + } + for _, p := range edgePaths { + if _, err := exec.LookPath(p); err == nil { + return exec.Command(p, "--inprivate", url) + } + } + return nil +} + +// tryFallbackBrowsersLinuxChain tries known browsers on Linux. +func tryFallbackBrowsersLinuxChain(url string) *exec.Cmd { + type browserConfig struct { + name string + flag string + } + browsers := []browserConfig{ + {"google-chrome", "--incognito"}, + {"google-chrome-stable", "--incognito"}, + {"chromium", "--incognito"}, + {"chromium-browser", "--incognito"}, + {"firefox", "--private-window"}, + {"firefox-esr", "--private-window"}, + {"brave-browser", "--incognito"}, + {"microsoft-edge", "--inprivate"}, + } + for _, b := range browsers { + if path, err := exec.LookPath(b.name); err == nil { + return exec.Command(path, b.flag, url) + } + } + return nil +} + // IsAvailable checks if the system has a command available to open a web browser. // It verifies the presence of necessary commands for the current operating system. // // Returns: // - true if a browser can be opened, false otherwise. func IsAvailable() bool { - // First check if open-golang can work - testErr := open.Run("about:blank") - if testErr == nil { - return true - } - // Check platform-specific commands switch runtime.GOOS { case "darwin": diff --git a/internal/cmd/auth_manager.go b/internal/cmd/auth_manager.go index e6caa954..84d9b969 100644 --- a/internal/cmd/auth_manager.go +++ b/internal/cmd/auth_manager.go @@ -6,7 +6,7 @@ import ( // newAuthManager creates a new authentication manager instance with all supported // authenticators and a file-based token store. It initializes authenticators for -// Gemini, Codex, Claude, and Qwen providers. +// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers. // // Returns: // - *sdkAuth.Manager: A configured authentication manager instance @@ -19,6 +19,8 @@ func newAuthManager() *sdkAuth.Manager { sdkAuth.NewQwenAuthenticator(), sdkAuth.NewIFlowAuthenticator(), sdkAuth.NewAntigravityAuthenticator(), + sdkAuth.NewKiroAuthenticator(), + sdkAuth.NewGitHubCopilotAuthenticator(), ) return manager } diff --git a/internal/cmd/github_copilot_login.go b/internal/cmd/github_copilot_login.go new file mode 100644 index 00000000..056e811f --- /dev/null +++ b/internal/cmd/github_copilot_login.go @@ -0,0 +1,44 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoGitHubCopilotLogin triggers the OAuth device flow for GitHub Copilot and saves tokens. +// It initiates the device flow authentication, displays the user code for the user to enter +// at GitHub's verification URL, and waits for authorization before saving the tokens. +// +// Parameters: +// - cfg: The application configuration containing proxy and auth directory settings +// - options: Login options including browser behavior settings +func DoGitHubCopilotLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + } + + record, savedPath, err := manager.Login(context.Background(), "github-copilot", cfg, authOpts) + if err != nil { + log.Errorf("GitHub Copilot authentication failed: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("GitHub Copilot authentication successful!") +} diff --git a/internal/cmd/kiro_login.go b/internal/cmd/kiro_login.go new file mode 100644 index 00000000..5fc3b9eb --- /dev/null +++ b/internal/cmd/kiro_login.go @@ -0,0 +1,160 @@ +package cmd + +import ( + "context" + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + log "github.com/sirupsen/logrus" +) + +// DoKiroLogin triggers the Kiro authentication flow with Google OAuth. +// This is the default login method (same as --kiro-google-login). +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including Prompt field +func DoKiroLogin(cfg *config.Config, options *LoginOptions) { + // Use Google login as default + DoKiroGoogleLogin(cfg, options) +} + +// DoKiroGoogleLogin triggers Kiro authentication with Google OAuth. +// This uses a custom protocol handler (kiro://) to receive the callback. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including prompts +func DoKiroGoogleLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + // Note: Kiro defaults to incognito mode for multi-account support. + // Users can override with --no-incognito if they want to use existing browser sessions. + + manager := newAuthManager() + + // Use KiroAuthenticator with Google login + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.LoginWithGoogle(context.Background(), cfg, &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + }) + if err != nil { + log.Errorf("Kiro Google authentication failed: %v", err) + fmt.Println("\nTroubleshooting:") + fmt.Println("1. Make sure the protocol handler is installed") + fmt.Println("2. Complete the Google login in the browser") + fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)") + return + } + + // Save the auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kiro Google authentication successful!") +} + +// DoKiroAWSLogin triggers Kiro authentication with AWS Builder ID. +// This uses the device code flow for AWS SSO OIDC authentication. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options including prompts +func DoKiroAWSLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + // Note: Kiro defaults to incognito mode for multi-account support. + // Users can override with --no-incognito if they want to use existing browser sessions. + + manager := newAuthManager() + + // Use KiroAuthenticator with AWS Builder ID login (device code flow) + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + Metadata: map[string]string{}, + Prompt: options.Prompt, + }) + if err != nil { + log.Errorf("Kiro AWS authentication failed: %v", err) + fmt.Println("\nTroubleshooting:") + fmt.Println("1. Make sure you have an AWS Builder ID") + fmt.Println("2. Complete the authorization in the browser") + fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)") + return + } + + // Save the auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Authenticated as %s\n", record.Label) + } + fmt.Println("Kiro AWS authentication successful!") +} + +// DoKiroImport imports Kiro token from Kiro IDE's token file. +// This is useful for users who have already logged in via Kiro IDE +// and want to use the same credentials in CLI Proxy API. +// +// Parameters: +// - cfg: The application configuration +// - options: Login options (currently unused for import) +func DoKiroImport(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + manager := newAuthManager() + + // Use ImportFromKiroIDE instead of Login + authenticator := sdkAuth.NewKiroAuthenticator() + record, err := authenticator.ImportFromKiroIDE(context.Background(), cfg) + if err != nil { + log.Errorf("Kiro token import failed: %v", err) + fmt.Println("\nMake sure you have logged in to Kiro IDE first:") + fmt.Println("1. Open Kiro IDE") + fmt.Println("2. Click 'Sign in with Google' (or GitHub)") + fmt.Println("3. Complete the login process") + fmt.Println("4. Run this command again") + return + } + + // Save the imported auth record + savedPath, err := manager.SaveAuth(record, cfg) + if err != nil { + log.Errorf("Failed to save auth: %v", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + if record != nil && record.Label != "" { + fmt.Printf("Imported as %s\n", record.Label) + } + fmt.Println("Kiro token import successful!") +} diff --git a/internal/config/config.go b/internal/config/config.go index 2310d7c2..3056fc4b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -63,6 +63,13 @@ type Config struct { // GeminiKey defines Gemini API key configurations with optional routing overrides. GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"` + // KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations. + KiroKey []KiroKey `yaml:"kiro" json:"kiro"` + + // KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers. + // Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q). + KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"` + // Codex defines a list of Codex API key configurations as specified in the YAML configuration file. CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"` @@ -85,6 +92,11 @@ type Config struct { // Payload defines default and override rules for provider payload parameters. Payload PayloadConfig `yaml:"payload" json:"payload"` + // IncognitoBrowser enables opening OAuth URLs in incognito/private browsing mode. + // This is useful when you want to login with a different account without logging out + // from your current session. Default: false. + IncognitoBrowser bool `yaml:"incognito-browser" json:"incognito-browser"` + legacyMigrationPending bool `yaml:"-" json:"-"` } @@ -252,6 +264,35 @@ type GeminiKey struct { ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` } +// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication. +type KiroKey struct { + // TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json) + TokenFile string `yaml:"token-file,omitempty" json:"token-file,omitempty"` + + // AccessToken is the OAuth access token for direct configuration. + AccessToken string `yaml:"access-token,omitempty" json:"access-token,omitempty"` + + // RefreshToken is the OAuth refresh token for token renewal. + RefreshToken string `yaml:"refresh-token,omitempty" json:"refresh-token,omitempty"` + + // ProfileArn is the AWS CodeWhisperer profile ARN. + ProfileArn string `yaml:"profile-arn,omitempty" json:"profile-arn,omitempty"` + + // Region is the AWS region (default: us-east-1). + Region string `yaml:"region,omitempty" json:"region,omitempty"` + + // ProxyURL optionally overrides the global proxy for this configuration. + ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // AgentTaskType sets the Kiro API task type. Known values: "vibe", "dev", "chat". + // Leave empty to let API use defaults. Different values may inject different system prompts. + AgentTaskType string `yaml:"agent-task-type,omitempty" json:"agent-task-type,omitempty"` + + // PreferredEndpoint sets the preferred Kiro API endpoint/quota. + // Values: "codewhisperer" (default, IDE quota) or "amazonq" (CLI quota). + PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"` +} + // OpenAICompatibility represents the configuration for OpenAI API compatibility // with external providers, allowing model aliases to be routed through OpenAI API format. type OpenAICompatibility struct { @@ -334,6 +375,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { cfg.DisableCooling = false cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository + cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force) if err = yaml.Unmarshal(data, &cfg); err != nil { if optional { // In cloud deploy mode, if YAML parsing fails, return empty config instead of error. @@ -389,6 +431,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize Claude key headers cfg.SanitizeClaudeKeys() + // Sanitize Kiro keys: trim whitespace from credential fields + cfg.SanitizeKiroKeys() + // Sanitize OpenAI compatibility providers: drop entries without base-url cfg.SanitizeOpenAICompatibility() @@ -465,6 +510,23 @@ func (cfg *Config) SanitizeClaudeKeys() { } } +// SanitizeKiroKeys trims whitespace from Kiro credential fields. +func (cfg *Config) SanitizeKiroKeys() { + if cfg == nil || len(cfg.KiroKey) == 0 { + return + } + for i := range cfg.KiroKey { + entry := &cfg.KiroKey[i] + entry.TokenFile = strings.TrimSpace(entry.TokenFile) + entry.AccessToken = strings.TrimSpace(entry.AccessToken) + entry.RefreshToken = strings.TrimSpace(entry.RefreshToken) + entry.ProfileArn = strings.TrimSpace(entry.ProfileArn) + entry.Region = strings.TrimSpace(entry.Region) + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + entry.PreferredEndpoint = strings.TrimSpace(entry.PreferredEndpoint) + } +} + // SanitizeGeminiKeys deduplicates and normalizes Gemini credentials. func (cfg *Config) SanitizeGeminiKeys() { if cfg == nil { diff --git a/internal/constant/constant.go b/internal/constant/constant.go index 58b388a1..1dbeecde 100644 --- a/internal/constant/constant.go +++ b/internal/constant/constant.go @@ -24,4 +24,7 @@ const ( // Antigravity represents the Antigravity response format identifier. Antigravity = "antigravity" + + // Kiro represents the AWS CodeWhisperer (Kiro) provider identifier. + Kiro = "kiro" ) diff --git a/internal/logging/global_logger.go b/internal/logging/global_logger.go index 28fde213..74a79efc 100644 --- a/internal/logging/global_logger.go +++ b/internal/logging/global_logger.go @@ -38,13 +38,16 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { timestamp := entry.Time.Format("2006-01-02 15:04:05") message := strings.TrimRight(entry.Message, "\r\n") - - var formatted string + + // Handle nil Caller (can happen with some log entries) + callerFile := "unknown" + callerLine := 0 if entry.Caller != nil { - formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, filepath.Base(entry.Caller.File), entry.Caller.Line, message) - } else { - formatted = fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, message) + callerFile = filepath.Base(entry.Caller.File) + callerLine = entry.Caller.Line } + + formatted := fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, callerFile, callerLine, message) buffer.WriteString(formatted) return buffer.Bytes(), nil @@ -55,6 +58,7 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { func SetupBaseLogger() { setupOnce.Do(func() { log.SetOutput(os.Stdout) + log.SetLevel(log.InfoLevel) log.SetReportCaller(true) log.SetFormatter(&LogFormatter{}) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 7a4bdf0c..f9310470 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -697,3 +697,353 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { "gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, } } + +// GetGitHubCopilotModels returns the available models for GitHub Copilot. +// These models are available through the GitHub Copilot API at api.githubcopilot.com. +func GetGitHubCopilotModels() []*ModelInfo { + now := int64(1732752000) // 2024-11-27 + return []*ModelInfo{ + { + ID: "gpt-4.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-4.1", + Description: "OpenAI GPT-4.1 via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "gpt-5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5", + Description: "OpenAI GPT-5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { + ID: "gpt-5-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5 Mini", + Description: "OpenAI GPT-5 Mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "gpt-5-codex", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5 Codex", + Description: "OpenAI GPT-5 Codex via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { + ID: "gpt-5.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1", + Description: "OpenAI GPT-5.1 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { + ID: "gpt-5.1-codex", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1 Codex", + Description: "OpenAI GPT-5.1 Codex via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32768, + }, + { + ID: "gpt-5.1-codex-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "GPT-5.1 Codex Mini", + Description: "OpenAI GPT-5.1 Codex Mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "claude-haiku-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Haiku 4.5", + Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "claude-opus-4.1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Opus 4.1", + Description: "Anthropic Claude Opus 4.1 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 32000, + }, + { + ID: "claude-opus-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Opus 4.5", + Description: "Anthropic Claude Opus 4.5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "claude-sonnet-4", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Sonnet 4", + Description: "Anthropic Claude Sonnet 4 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "claude-sonnet-4.5", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Claude Sonnet 4.5", + Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "gemini-2.5-pro", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Gemini 2.5 Pro", + Description: "Google Gemini 2.5 Pro via GitHub Copilot", + ContextLength: 1048576, + MaxCompletionTokens: 65536, + }, + { + ID: "gemini-3-pro", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Gemini 3 Pro", + Description: "Google Gemini 3 Pro via GitHub Copilot", + ContextLength: 1048576, + MaxCompletionTokens: 65536, + }, + { + ID: "grok-code-fast-1", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Grok Code Fast 1", + Description: "xAI Grok Code Fast 1 via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + { + ID: "raptor-mini", + Object: "model", + Created: now, + OwnedBy: "github-copilot", + Type: "github-copilot", + DisplayName: "Raptor Mini", + Description: "Raptor Mini via GitHub Copilot", + ContextLength: 128000, + MaxCompletionTokens: 16384, + }, + } +} + +// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions +func GetKiroModels() []*ModelInfo { + return []*ModelInfo{ + // --- Base Models --- + { + ID: "kiro-claude-opus-4-5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.5", + Description: "Claude Opus 4.5 via Kiro (2.2x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-sonnet-4-5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4.5", + Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-sonnet-4", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4", + Description: "Claude Sonnet 4 via Kiro (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-haiku-4-5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Haiku 4.5", + Description: "Claude Haiku 4.5 via Kiro (0.4x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + // --- Agentic Variants (Optimized for coding agents with chunked writes) --- + { + ID: "kiro-claude-opus-4-5-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Opus 4.5 (Agentic)", + Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-sonnet-4-5-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4.5 (Agentic)", + Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-sonnet-4-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Sonnet 4 (Agentic)", + Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + { + ID: "kiro-claude-haiku-4-5-agentic", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Kiro Claude Haiku 4.5 (Agentic)", + Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, + }, + } +} + +// GetAmazonQModels returns the Amazon Q (AWS CodeWhisperer) model definitions. +// These models use the same API as Kiro and share the same executor. +func GetAmazonQModels() []*ModelInfo { + return []*ModelInfo{ + { + ID: "amazonq-auto", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", // Uses Kiro executor - same API + DisplayName: "Amazon Q Auto", + Description: "Automatic model selection by Amazon Q", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-opus-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Opus 4.5", + Description: "Claude Opus 4.5 via Amazon Q (2.2x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-sonnet-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Sonnet 4.5", + Description: "Claude Sonnet 4.5 via Amazon Q (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-sonnet-4", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Sonnet 4", + Description: "Claude Sonnet 4 via Amazon Q (1.3x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + { + ID: "amazonq-claude-haiku-4.5", + Object: "model", + Created: 1732752000, + OwnedBy: "aws", + Type: "kiro", + DisplayName: "Amazon Q Claude Haiku 4.5", + Description: "Claude Haiku 4.5 via Amazon Q (0.4x credit)", + ContextLength: 200000, + MaxCompletionTokens: 64000, + }, + } +} diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index d4f84481..a34cf742 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -766,7 +766,8 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) } return result - case "claude": + case "claude", "kiro", "antigravity": + // Claude, Kiro, and Antigravity all use Claude-compatible format for Claude Code client result := map[string]any{ "id": model.ID, "object": "model", @@ -781,6 +782,19 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) if model.DisplayName != "" { result["display_name"] = model.DisplayName } + // Add thinking support for Claude Code client + // Claude Code checks for "thinking" field (simple boolean) to enable tab toggle + // Also add "extended_thinking" for detailed budget info + if model.Thinking != nil { + result["thinking"] = true + result["extended_thinking"] = map[string]any{ + "supported": true, + "min": model.Thinking.Min, + "max": model.Thinking.Max, + "zero_allowed": model.Thinking.ZeroAllowed, + "dynamic_allowed": model.Thinking.DynamicAllowed, + } + } return result case "gemini": diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 5884f705..a77bc037 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -15,6 +15,7 @@ import ( "net/url" "strconv" "strings" + "sync" "time" "github.com/google/uuid" @@ -43,7 +44,10 @@ const ( refreshSkew = 3000 * time.Second ) -var randSource = rand.New(rand.NewSource(time.Now().UnixNano())) +var ( + randSource = rand.New(rand.NewSource(time.Now().UnixNano())) + randSourceMutex sync.Mutex +) // AntigravityExecutor proxies requests to the antigravity upstream. type AntigravityExecutor struct { @@ -777,15 +781,19 @@ func generateRequestID() string { } func generateSessionID() string { + randSourceMutex.Lock() n := randSource.Int63n(9_000_000_000_000_000_000) + randSourceMutex.Unlock() return "-" + strconv.FormatInt(n, 10) } func generateProjectID() string { adjectives := []string{"useful", "bright", "swift", "calm", "bold"} nouns := []string{"fuze", "wave", "spark", "flow", "core"} + randSourceMutex.Lock() adj := adjectives[randSource.Intn(len(adjectives))] noun := nouns[randSource.Intn(len(nouns))] + randSourceMutex.Unlock() randomPart := strings.ToLower(uuid.NewString())[:5] return adj + "-" + noun + "-" + randomPart } diff --git a/internal/runtime/executor/cache_helpers.go b/internal/runtime/executor/cache_helpers.go index 5272686b..4b553662 100644 --- a/internal/runtime/executor/cache_helpers.go +++ b/internal/runtime/executor/cache_helpers.go @@ -1,10 +1,38 @@ package executor -import "time" +import ( + "sync" + "time" +) type codexCache struct { ID string Expire time.Time } -var codexCacheMap = map[string]codexCache{} +var ( + codexCacheMap = map[string]codexCache{} + codexCacheMutex sync.RWMutex +) + +// getCodexCache safely retrieves a cache entry +func getCodexCache(key string) (codexCache, bool) { + codexCacheMutex.RLock() + defer codexCacheMutex.RUnlock() + cache, ok := codexCacheMap[key] + return cache, ok +} + +// setCodexCache safely sets a cache entry +func setCodexCache(key string, cache codexCache) { + codexCacheMutex.Lock() + defer codexCacheMutex.Unlock() + codexCacheMap[key] = cache +} + +// deleteCodexCache safely deletes a cache entry +func deleteCodexCache(key string) { + codexCacheMutex.Lock() + defer codexCacheMutex.Unlock() + delete(codexCacheMap, key) +} diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index c3e14701..4caf05c4 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -442,12 +442,12 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form if userIDResult.Exists() { var hasKey bool key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String()) - if cache, hasKey = codexCacheMap[key]; !hasKey || cache.Expire.Before(time.Now()) { + if cache, hasKey = getCodexCache(key); !hasKey || cache.Expire.Before(time.Now()) { cache = codexCache{ ID: uuid.New().String(), Expire: time.Now().Add(1 * time.Hour), } - codexCacheMap[key] = cache + setCodexCache(key, cache) } } } else if from == "openai-response" { diff --git a/internal/runtime/executor/github_copilot_executor.go b/internal/runtime/executor/github_copilot_executor.go new file mode 100644 index 00000000..b2bc73df --- /dev/null +++ b/internal/runtime/executor/github_copilot_executor.go @@ -0,0 +1,361 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net/http" + "sync" + "time" + + "github.com/google/uuid" + copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + log "github.com/sirupsen/logrus" + "github.com/tidwall/sjson" +) + +const ( + githubCopilotBaseURL = "https://api.githubcopilot.com" + githubCopilotChatPath = "/chat/completions" + githubCopilotAuthType = "github-copilot" + githubCopilotTokenCacheTTL = 25 * time.Minute + // tokenExpiryBuffer is the time before expiry when we should refresh the token. + tokenExpiryBuffer = 5 * time.Minute + // maxScannerBufferSize is the maximum buffer size for SSE scanning (20MB). + maxScannerBufferSize = 20_971_520 + + // Copilot API header values. + copilotUserAgent = "GithubCopilot/1.0" + copilotEditorVersion = "vscode/1.100.0" + copilotPluginVersion = "copilot/1.300.0" + copilotIntegrationID = "vscode-chat" + copilotOpenAIIntent = "conversation-panel" +) + +// GitHubCopilotExecutor handles requests to the GitHub Copilot API. +type GitHubCopilotExecutor struct { + cfg *config.Config + mu sync.RWMutex + cache map[string]*cachedAPIToken +} + +// cachedAPIToken stores a cached Copilot API token with its expiry. +type cachedAPIToken struct { + token string + expiresAt time.Time +} + +// NewGitHubCopilotExecutor constructs a new executor instance. +func NewGitHubCopilotExecutor(cfg *config.Config) *GitHubCopilotExecutor { + return &GitHubCopilotExecutor{ + cfg: cfg, + cache: make(map[string]*cachedAPIToken), + } +} + +// Identifier implements ProviderExecutor. +func (e *GitHubCopilotExecutor) Identifier() string { return githubCopilotAuthType } + +// PrepareRequest implements ProviderExecutor. +func (e *GitHubCopilotExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { + return nil +} + +// Execute handles non-streaming requests to GitHub Copilot. +func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + apiToken, errToken := e.ensureAPIToken(ctx, auth) + if errToken != nil { + return resp, errToken + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + body = e.normalizeModel(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + body, _ = sjson.SetBytes(body, "stream", false) + + url := githubCopilotBaseURL + githubCopilotChatPath + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return resp, err + } + e.applyHeaders(httpReq, apiToken) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("github-copilot executor: close response body error: %v", errClose) + } + }() + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + if !isHTTPSuccess(httpResp.StatusCode) { + data, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, data) + log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + err = statusErr{code: httpResp.StatusCode, msg: string(data)} + return resp, err + } + + data, err := io.ReadAll(httpResp.Body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + appendAPIResponseChunk(ctx, e.cfg, data) + + detail := parseOpenAIUsage(data) + if detail.TotalTokens > 0 { + reporter.publish(ctx, detail) + } + + var param any + converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(converted)} + reporter.ensurePublished(ctx) + return resp, nil +} + +// ExecuteStream handles streaming requests to GitHub Copilot. +func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + apiToken, errToken := e.ensureAPIToken(ctx, auth) + if errToken != nil { + return nil, errToken + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("openai") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + body = e.normalizeModel(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + body, _ = sjson.SetBytes(body, "stream", true) + // Enable stream options for usage stats in stream + body, _ = sjson.SetBytes(body, "stream_options.include_usage", true) + + url := githubCopilotBaseURL + githubCopilotChatPath + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + e.applyHeaders(httpReq, apiToken) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + if !isHTTPSuccess(httpResp.StatusCode) { + data, readErr := io.ReadAll(httpResp.Body) + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("github-copilot executor: close response body error: %v", errClose) + } + if readErr != nil { + recordAPIResponseError(ctx, e.cfg, readErr) + return nil, readErr + } + appendAPIResponseChunk(ctx, e.cfg, data) + log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + err = statusErr{code: httpResp.StatusCode, msg: string(data)} + return nil, err + } + + out := make(chan cliproxyexecutor.StreamChunk) + stream = out + + go func() { + defer close(out) + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("github-copilot executor: close response body error: %v", errClose) + } + }() + + scanner := bufio.NewScanner(httpResp.Body) + scanner.Buffer(nil, maxScannerBufferSize) + var param any + + for scanner.Scan() { + line := scanner.Bytes() + appendAPIResponseChunk(ctx, e.cfg, line) + + // Parse SSE data + if bytes.HasPrefix(line, dataTag) { + data := bytes.TrimSpace(line[5:]) + if bytes.Equal(data, []byte("[DONE]")) { + continue + } + if detail, ok := parseOpenAIStreamUsage(line); ok { + reporter.publish(ctx, detail) + } + } + + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + for i := range chunks { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} + } + } + + if errScan := scanner.Err(); errScan != nil { + recordAPIResponseError(ctx, e.cfg, errScan) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: errScan} + } else { + reporter.ensurePublished(ctx) + } + }() + + return stream, nil +} + +// CountTokens is not supported for GitHub Copilot. +func (e *GitHubCopilotExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for github-copilot"} +} + +// Refresh validates the GitHub token is still working. +// GitHub OAuth tokens don't expire traditionally, so we just validate. +func (e *GitHubCopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if auth == nil { + return nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"} + } + + // Get the GitHub access token + accessToken := metaStringValue(auth.Metadata, "access_token") + if accessToken == "" { + return auth, nil + } + + // Validate the token can still get a Copilot API token + copilotAuth := copilotauth.NewCopilotAuth(e.cfg) + _, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken) + if err != nil { + return nil, statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("github-copilot token validation failed: %v", err)} + } + + return auth, nil +} + +// ensureAPIToken gets or refreshes the Copilot API token. +func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, error) { + if auth == nil { + return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"} + } + + // Get the GitHub access token + accessToken := metaStringValue(auth.Metadata, "access_token") + if accessToken == "" { + return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"} + } + + // Check for cached API token using thread-safe access + e.mu.RLock() + if cached, ok := e.cache[accessToken]; ok && cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) { + e.mu.RUnlock() + return cached.token, nil + } + e.mu.RUnlock() + + // Get a new Copilot API token + copilotAuth := copilotauth.NewCopilotAuth(e.cfg) + apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken) + if err != nil { + return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)} + } + + // Cache the token with thread-safe access + expiresAt := time.Now().Add(githubCopilotTokenCacheTTL) + if apiToken.ExpiresAt > 0 { + expiresAt = time.Unix(apiToken.ExpiresAt, 0) + } + e.mu.Lock() + e.cache[accessToken] = &cachedAPIToken{ + token: apiToken.Token, + expiresAt: expiresAt, + } + e.mu.Unlock() + + return apiToken.Token, nil +} + +// applyHeaders sets the required headers for GitHub Copilot API requests. +func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string) { + r.Header.Set("Content-Type", "application/json") + r.Header.Set("Authorization", "Bearer "+apiToken) + r.Header.Set("Accept", "application/json") + r.Header.Set("User-Agent", copilotUserAgent) + r.Header.Set("Editor-Version", copilotEditorVersion) + r.Header.Set("Editor-Plugin-Version", copilotPluginVersion) + r.Header.Set("Openai-Intent", copilotOpenAIIntent) + r.Header.Set("Copilot-Integration-Id", copilotIntegrationID) + r.Header.Set("X-Request-Id", uuid.NewString()) +} + +// normalizeModel is a no-op as GitHub Copilot accepts model names directly. +// Model mapping should be done at the registry level if needed. +func (e *GitHubCopilotExecutor) normalizeModel(_ string, body []byte) []byte { + return body +} + +// isHTTPSuccess checks if the status code indicates success (2xx). +func isHTTPSuccess(statusCode int) bool { + return statusCode >= 200 && statusCode < 300 +} diff --git a/internal/runtime/executor/kiro_executor.go b/internal/runtime/executor/kiro_executor.go new file mode 100644 index 00000000..be6be1ed --- /dev/null +++ b/internal/runtime/executor/kiro_executor.go @@ -0,0 +1,3195 @@ +package executor + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/google/uuid" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + kiroclaude "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude" + kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" + kiroopenai "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + log "github.com/sirupsen/logrus" + +) + +const ( + // Kiro API common constants + kiroContentType = "application/x-amz-json-1.0" + kiroAcceptStream = "*/*" + + // Event Stream frame size constants for boundary protection + // AWS Event Stream binary format: prelude (12 bytes) + headers + payload + message_crc (4 bytes) + // Prelude consists of: total_length (4) + headers_length (4) + prelude_crc (4) + minEventStreamFrameSize = 16 // Minimum: 4(total_len) + 4(headers_len) + 4(prelude_crc) + 4(message_crc) + maxEventStreamMsgSize = 10 << 20 // Maximum message length: 10MB + + // Event Stream error type constants + ErrStreamFatal = "fatal" // Connection/authentication errors, not recoverable + ErrStreamMalformed = "malformed" // Format errors, data cannot be parsed + // kiroUserAgent matches amq2api format for User-Agent header + kiroUserAgent = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0" + // kiroFullUserAgent is the complete x-amz-user-agent header matching amq2api + kiroFullUserAgent = "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI" +) + +// Real-time usage estimation configuration +// These control how often usage updates are sent during streaming +var ( + usageUpdateCharThreshold = 5000 // Send usage update every 5000 characters + usageUpdateTimeInterval = 15 * time.Second // Or every 15 seconds, whichever comes first +) + +// kiroEndpointConfig bundles endpoint URL with its compatible Origin and AmzTarget values. +// This solves the "triple mismatch" problem where different endpoints require matching +// Origin and X-Amz-Target header values. +// +// Based on reference implementations: +// - amq2api-main: Uses Amazon Q endpoint with CLI origin and AmazonQDeveloperStreamingService target +// - AIClient-2-API: Uses CodeWhisperer endpoint with AI_EDITOR origin and AmazonCodeWhispererStreamingService target +type kiroEndpointConfig struct { + URL string // Endpoint URL + Origin string // Request Origin: "CLI" for Amazon Q quota, "AI_EDITOR" for Kiro IDE quota + AmzTarget string // X-Amz-Target header value + Name string // Endpoint name for logging +} + +// kiroEndpointConfigs defines the available Kiro API endpoints with their compatible configurations. +// The order determines fallback priority: primary endpoint first, then fallbacks. +// +// CRITICAL: Each endpoint MUST use its compatible Origin and AmzTarget values: +// - CodeWhisperer endpoint (codewhisperer.us-east-1.amazonaws.com): Uses AI_EDITOR origin and AmazonCodeWhispererStreamingService target +// - Amazon Q endpoint (q.us-east-1.amazonaws.com): Uses CLI origin and AmazonQDeveloperStreamingService target +// +// Mismatched combinations will result in 403 Forbidden errors. +// +// NOTE: CodeWhisperer is set as the default endpoint because: +// 1. Most tokens come from Kiro IDE / VSCode extensions (AWS Builder ID auth) +// 2. These tokens use AI_EDITOR origin which is only compatible with CodeWhisperer endpoint +// 3. Amazon Q endpoint requires CLI origin which is for Amazon Q CLI tokens +// This matches the AIClient-2-API-main project's configuration. +var kiroEndpointConfigs = []kiroEndpointConfig{ + { + URL: "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse", + Origin: "AI_EDITOR", + AmzTarget: "AmazonCodeWhispererStreamingService.GenerateAssistantResponse", + Name: "CodeWhisperer", + }, + { + URL: "https://q.us-east-1.amazonaws.com/", + Origin: "CLI", + AmzTarget: "AmazonQDeveloperStreamingService.SendMessage", + Name: "AmazonQ", + }, +} + +// getKiroEndpointConfigs returns the list of Kiro API endpoint configurations to try in order. +// Supports reordering based on "preferred_endpoint" in auth metadata/attributes. +func getKiroEndpointConfigs(auth *cliproxyauth.Auth) []kiroEndpointConfig { + if auth == nil { + return kiroEndpointConfigs + } + + // Check for preference + var preference string + if auth.Metadata != nil { + if p, ok := auth.Metadata["preferred_endpoint"].(string); ok { + preference = p + } + } + // Check attributes as fallback (e.g. from HTTP headers) + if preference == "" && auth.Attributes != nil { + preference = auth.Attributes["preferred_endpoint"] + } + + if preference == "" { + return kiroEndpointConfigs + } + + preference = strings.ToLower(strings.TrimSpace(preference)) + + // Create new slice to avoid modifying global state + var sorted []kiroEndpointConfig + var remaining []kiroEndpointConfig + + for _, cfg := range kiroEndpointConfigs { + name := strings.ToLower(cfg.Name) + // Check for matches + // CodeWhisperer aliases: codewhisperer, ide + // AmazonQ aliases: amazonq, q, cli + isMatch := false + if (preference == "codewhisperer" || preference == "ide") && name == "codewhisperer" { + isMatch = true + } else if (preference == "amazonq" || preference == "q" || preference == "cli") && name == "amazonq" { + isMatch = true + } + + if isMatch { + sorted = append(sorted, cfg) + } else { + remaining = append(remaining, cfg) + } + } + + // If preference didn't match anything, return default + if len(sorted) == 0 { + return kiroEndpointConfigs + } + + // Combine: preferred first, then others + return append(sorted, remaining...) +} + +// KiroExecutor handles requests to AWS CodeWhisperer (Kiro) API. +type KiroExecutor struct { + cfg *config.Config + refreshMu sync.Mutex // Serializes token refresh operations to prevent race conditions +} + +// buildKiroPayloadForFormat builds the Kiro API payload based on the source format. +// This is critical because OpenAI and Claude formats have different tool structures: +// - OpenAI: tools[].function.name, tools[].function.description +// - Claude: tools[].name, tools[].description +// Returns the serialized JSON payload and a boolean indicating whether thinking mode was injected. +func buildKiroPayloadForFormat(body []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool, sourceFormat sdktranslator.Format) ([]byte, bool) { + switch sourceFormat.String() { + case "openai": + log.Debugf("kiro: using OpenAI payload builder for source format: %s", sourceFormat.String()) + return kiroopenai.BuildKiroPayloadFromOpenAI(body, modelID, profileArn, origin, isAgentic, isChatOnly) + default: + // Default to Claude format (also handles "claude", "kiro", etc.) + log.Debugf("kiro: using Claude payload builder for source format: %s", sourceFormat.String()) + return kiroclaude.BuildKiroPayload(body, modelID, profileArn, origin, isAgentic, isChatOnly) + } +} + +// NewKiroExecutor creates a new Kiro executor instance. +func NewKiroExecutor(cfg *config.Config) *KiroExecutor { + return &KiroExecutor{cfg: cfg} +} + +// Identifier returns the unique identifier for this executor. +func (e *KiroExecutor) Identifier() string { return "kiro" } + +// PrepareRequest prepares the HTTP request before execution. +func (e *KiroExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } + +// Execute sends the request to Kiro API and returns the response. +// Supports automatic token refresh on 401/403 errors. +func (e *KiroExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + accessToken, profileArn := kiroCredentials(auth) + if accessToken == "" { + return resp, fmt.Errorf("kiro: access token not found in auth") + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + // Check if token is expired before making request + if e.isTokenExpired(accessToken) { + log.Infof("kiro: access token expired, attempting refresh before request") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) + } else if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + log.Infof("kiro: token refreshed successfully before request") + } + } + + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + + kiroModelID := e.mapModelToKiro(req.Model) + + // Determine agentic mode and effective profile ARN using helper functions + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) + + // Execute with retry on 401/403 and 429 (quota exhausted) + // Note: currentOrigin and kiroPayload are built inside executeWithRetry for each endpoint + resp, err = e.executeWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, to, reporter, "", kiroModelID, isAgentic, isChatOnly) + return resp, err +} + +// executeWithRetry performs the actual HTTP request with automatic retry on auth errors. +// Supports automatic fallback between endpoints with different quotas: +// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota +// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota +// Also supports multi-endpoint fallback similar to Antigravity implementation. +func (e *KiroExecutor) executeWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from, to sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (cliproxyexecutor.Response, error) { + var resp cliproxyexecutor.Response + maxRetries := 2 // Allow retries for token refresh + endpoint fallback + endpointConfigs := getKiroEndpointConfigs(auth) + var last429Err error + + for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ { + endpointConfig := endpointConfigs[endpointIdx] + url := endpointConfig.URL + // Use this endpoint's compatible Origin (critical for avoiding 403 errors) + currentOrigin = endpointConfig.Origin + + // Rebuild payload with the correct origin for this endpoint + // Each endpoint requires its matching Origin value in the request body + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + + log.Debugf("kiro: trying endpoint %d/%d: %s (Name: %s, Origin: %s)", + endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) + + for attempt := 0; attempt <= maxRetries; attempt++ { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) + if err != nil { + return resp, err + } + + httpReq.Header.Set("Content-Type", kiroContentType) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + httpReq.Header.Set("Accept", kiroAcceptStream) + // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) + httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) + httpReq.Header.Set("User-Agent", kiroUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(httpReq, attrs) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: kiroPayload, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 120*time.Second) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + // Handle 429 errors (quota exhausted) - try next endpoint + // Each endpoint has its own quota pool, so we can try different endpoints + if httpResp.StatusCode == 429 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + // Preserve last 429 so callers can correctly backoff when all endpoints are exhausted + last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)} + + log.Warnf("kiro: %s endpoint quota exhausted (429), will try next endpoint, body: %s", + endpointConfig.Name, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + + // Break inner retry loop to try next endpoint (which has different quota) + break + } + + // Handle 5xx server errors with exponential backoff retry + if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + if attempt < maxRetries { + // Exponential backoff: 1s, 2s, 4s... (max 30s) + backoff := time.Duration(1< 30*time.Second { + backoff = 30 * time.Second + } + log.Warnf("kiro: server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) + time.Sleep(backoff) + continue + } + log.Errorf("kiro: server error %d after %d retries", httpResp.StatusCode, maxRetries) + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 401 errors with token refresh and retry + // 401 = Unauthorized (token expired/invalid) - refresh token + if httpResp.StatusCode == 401 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + if attempt < maxRetries { + log.Warnf("kiro: received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) + + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + // Rebuild payload with new profile ARN if changed + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + log.Infof("kiro: token refreshed successfully, retrying request") + continue + } + } + + log.Warnf("kiro request error, status: 401, body: %s", summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 402 errors - Monthly Limit Reached + if httpResp.StatusCode == 402 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + log.Warnf("kiro: received 402 (monthly limit). Upstream body: %s", string(respBody)) + + // Return upstream error body directly + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 403 errors - Access Denied / Token Expired + // Do NOT switch endpoints for 403 errors + if httpResp.StatusCode == 403 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + // Log the 403 error details for debugging + log.Warnf("kiro: received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + + respBodyStr := string(respBody) + + // Check for SUSPENDED status - return immediately without retry + if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { + log.Errorf("kiro: account is suspended, cannot proceed") + return resp, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} + } + + // Check if this looks like a token-related 403 (some APIs return 403 for expired tokens) + isTokenRelated := strings.Contains(respBodyStr, "token") || + strings.Contains(respBodyStr, "expired") || + strings.Contains(respBodyStr, "invalid") || + strings.Contains(respBodyStr, "unauthorized") + + if isTokenRelated && attempt < maxRetries { + log.Warnf("kiro: 403 appears token-related, attempting token refresh") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + // Token refresh failed - return error immediately + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + log.Infof("kiro: token refreshed for 403, retrying request") + continue + } + } + + // For non-token 403 or after max retries, return error immediately + // Do NOT switch endpoints for 403 errors + log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)") + return resp, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("kiro request error, status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + err = statusErr{code: httpResp.StatusCode, msg: string(b)} + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + return resp, err + } + + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + }() + + content, toolUses, usageInfo, stopReason, err := e.parseEventStream(httpResp.Body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + + // Fallback for usage if missing from upstream + if usageInfo.TotalTokens == 0 { + if enc, encErr := getTokenizer(req.Model); encErr == nil { + if inp, countErr := countOpenAIChatTokens(enc, opts.OriginalRequest); countErr == nil { + usageInfo.InputTokens = inp + } + } + if len(content) > 0 { + // Use tiktoken for more accurate output token calculation + if enc, encErr := getTokenizer(req.Model); encErr == nil { + if tokenCount, countErr := enc.Count(content); countErr == nil { + usageInfo.OutputTokens = int64(tokenCount) + } + } + // Fallback to character count estimation if tiktoken fails + if usageInfo.OutputTokens == 0 { + usageInfo.OutputTokens = int64(len(content) / 4) + if usageInfo.OutputTokens == 0 { + usageInfo.OutputTokens = 1 + } + } + } + usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens + } + + appendAPIResponseChunk(ctx, e.cfg, []byte(content)) + reporter.publish(ctx, usageInfo) + + // Build response in Claude format for Kiro translator + // stopReason is extracted from upstream response by parseEventStream + kiroResponse := kiroclaude.BuildClaudeResponse(content, toolUses, req.Model, usageInfo, stopReason) + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, kiroResponse, nil) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil + } + // Inner retry loop exhausted for this endpoint, try next endpoint + // Note: This code is unreachable because all paths in the inner loop + // either return or continue. Kept as comment for documentation. + } + + // All endpoints exhausted + if last429Err != nil { + return resp, last429Err + } + return resp, fmt.Errorf("kiro: all endpoints exhausted") +} + +// ExecuteStream handles streaming requests to Kiro API. +// Supports automatic token refresh on 401/403 errors and quota fallback on 429. +func (e *KiroExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + accessToken, profileArn := kiroCredentials(auth) + if accessToken == "" { + return nil, fmt.Errorf("kiro: access token not found in auth") + } + + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + // Check if token is expired before making request + if e.isTokenExpired(accessToken) { + log.Infof("kiro: access token expired, attempting refresh before stream request") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Warnf("kiro: pre-request token refresh failed: %v", refreshErr) + } else if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + log.Infof("kiro: token refreshed successfully before stream request") + } + } + + from := opts.SourceFormat + to := sdktranslator.FromString("kiro") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + + kiroModelID := e.mapModelToKiro(req.Model) + + // Determine agentic mode and effective profile ARN using helper functions + isAgentic, isChatOnly := determineAgenticMode(req.Model) + effectiveProfileArn := getEffectiveProfileArnWithWarning(auth, profileArn) + + // Execute stream with retry on 401/403 and 429 (quota exhausted) + // Note: currentOrigin and kiroPayload are built inside executeStreamWithRetry for each endpoint + return e.executeStreamWithRetry(ctx, auth, req, opts, accessToken, effectiveProfileArn, nil, body, from, reporter, "", kiroModelID, isAgentic, isChatOnly) +} + +// executeStreamWithRetry performs the streaming HTTP request with automatic retry on auth errors. +// Supports automatic fallback between endpoints with different quotas: +// - Amazon Q endpoint (CLI origin) uses Amazon Q Developer quota +// - CodeWhisperer endpoint (AI_EDITOR origin) uses Kiro IDE quota +// Also supports multi-endpoint fallback similar to Antigravity implementation. +func (e *KiroExecutor) executeStreamWithRetry(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, accessToken, profileArn string, kiroPayload, body []byte, from sdktranslator.Format, reporter *usageReporter, currentOrigin, kiroModelID string, isAgentic, isChatOnly bool) (<-chan cliproxyexecutor.StreamChunk, error) { + maxRetries := 2 // Allow retries for token refresh + endpoint fallback + endpointConfigs := getKiroEndpointConfigs(auth) + var last429Err error + + for endpointIdx := 0; endpointIdx < len(endpointConfigs); endpointIdx++ { + endpointConfig := endpointConfigs[endpointIdx] + url := endpointConfig.URL + // Use this endpoint's compatible Origin (critical for avoiding 403 errors) + currentOrigin = endpointConfig.Origin + + // Rebuild payload with the correct origin for this endpoint + // Each endpoint requires its matching Origin value in the request body + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + // Kiro API always returns tags regardless of whether thinking mode was requested + // So we always enable thinking parsing for Kiro responses + thinkingEnabled := true + + log.Debugf("kiro: stream trying endpoint %d/%d: %s (Name: %s, Origin: %s)", + endpointIdx+1, len(endpointConfigs), url, endpointConfig.Name, currentOrigin) + + for attempt := 0; attempt <= maxRetries; attempt++ { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(kiroPayload)) + if err != nil { + return nil, err + } + + httpReq.Header.Set("Content-Type", kiroContentType) + httpReq.Header.Set("Authorization", "Bearer "+accessToken) + httpReq.Header.Set("Accept", kiroAcceptStream) + // Use endpoint-specific X-Amz-Target (critical for avoiding 403 errors) + httpReq.Header.Set("X-Amz-Target", endpointConfig.AmzTarget) + httpReq.Header.Set("User-Agent", kiroUserAgent) + httpReq.Header.Set("X-Amz-User-Agent", kiroFullUserAgent) + httpReq.Header.Set("Amz-Sdk-Request", "attempt=1; max=3") + httpReq.Header.Set("Amz-Sdk-Invocation-Id", uuid.New().String()) + + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(httpReq, attrs) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: kiroPayload, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + return nil, err + } + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + // Handle 429 errors (quota exhausted) - try next endpoint + // Each endpoint has its own quota pool, so we can try different endpoints + if httpResp.StatusCode == 429 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + // Preserve last 429 so callers can correctly backoff when all endpoints are exhausted + last429Err = statusErr{code: httpResp.StatusCode, msg: string(respBody)} + + log.Warnf("kiro: stream %s endpoint quota exhausted (429), will try next endpoint, body: %s", + endpointConfig.Name, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + + // Break inner retry loop to try next endpoint (which has different quota) + break + } + + // Handle 5xx server errors with exponential backoff retry + if httpResp.StatusCode >= 500 && httpResp.StatusCode < 600 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + if attempt < maxRetries { + // Exponential backoff: 1s, 2s, 4s... (max 30s) + backoff := time.Duration(1< 30*time.Second { + backoff = 30 * time.Second + } + log.Warnf("kiro: stream server error %d, retrying in %v (attempt %d/%d)", httpResp.StatusCode, backoff, attempt+1, maxRetries) + time.Sleep(backoff) + continue + } + log.Errorf("kiro: stream server error %d after %d retries", httpResp.StatusCode, maxRetries) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 400 errors - Credential/Validation issues + // Do NOT switch endpoints - return error immediately + if httpResp.StatusCode == 400 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + log.Warnf("kiro: received 400 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, summarizeErrorBody(httpResp.Header.Get("Content-Type"), respBody)) + + // 400 errors indicate request validation issues - return immediately without retry + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 401 errors with token refresh and retry + // 401 = Unauthorized (token expired/invalid) - refresh token + if httpResp.StatusCode == 401 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + if attempt < maxRetries { + log.Warnf("kiro: stream received 401 error, attempting token refresh and retry (attempt %d/%d)", attempt+1, maxRetries+1) + + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + // Rebuild payload with new profile ARN if changed + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + log.Infof("kiro: token refreshed successfully, retrying stream request") + continue + } + } + + log.Warnf("kiro stream error, status: 401, body: %s", string(respBody)) + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 402 errors - Monthly Limit Reached + if httpResp.StatusCode == 402 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + log.Warnf("kiro: stream received 402 (monthly limit). Upstream body: %s", string(respBody)) + + // Return upstream error body directly + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + // Handle 403 errors - Access Denied / Token Expired + // Do NOT switch endpoints for 403 errors + if httpResp.StatusCode == 403 { + respBody, _ := io.ReadAll(httpResp.Body) + _ = httpResp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, respBody) + + // Log the 403 error details for debugging + log.Warnf("kiro: stream received 403 error (attempt %d/%d), body: %s", attempt+1, maxRetries+1, string(respBody)) + + respBodyStr := string(respBody) + + // Check for SUSPENDED status - return immediately without retry + if strings.Contains(respBodyStr, "SUSPENDED") || strings.Contains(respBodyStr, "TEMPORARILY_SUSPENDED") { + log.Errorf("kiro: account is suspended, cannot proceed") + return nil, statusErr{code: httpResp.StatusCode, msg: "account suspended: " + string(respBody)} + } + + // Check if this looks like a token-related 403 (some APIs return 403 for expired tokens) + isTokenRelated := strings.Contains(respBodyStr, "token") || + strings.Contains(respBodyStr, "expired") || + strings.Contains(respBodyStr, "invalid") || + strings.Contains(respBodyStr, "unauthorized") + + if isTokenRelated && attempt < maxRetries { + log.Warnf("kiro: 403 appears token-related, attempting token refresh") + refreshedAuth, refreshErr := e.Refresh(ctx, auth) + if refreshErr != nil { + log.Errorf("kiro: token refresh failed: %v", refreshErr) + // Token refresh failed - return error immediately + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + if refreshedAuth != nil { + auth = refreshedAuth + accessToken, profileArn = kiroCredentials(auth) + kiroPayload, _ = buildKiroPayloadForFormat(body, kiroModelID, profileArn, currentOrigin, isAgentic, isChatOnly, from) + log.Infof("kiro: token refreshed for 403, retrying stream request") + continue + } + } + + // For non-token 403 or after max retries, return error immediately + // Do NOT switch endpoints for 403 errors + log.Warnf("kiro: 403 error, returning immediately (no endpoint switch)") + return nil, statusErr{code: httpResp.StatusCode, msg: string(respBody)} + } + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("kiro stream error, status: %d, body: %s", httpResp.StatusCode, string(b)) + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} + } + + out := make(chan cliproxyexecutor.StreamChunk) + + go func(resp *http.Response, thinkingEnabled bool) { + defer close(out) + defer func() { + if r := recover(); r != nil { + log.Errorf("kiro: panic in stream handler: %v", r) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("internal error: %v", r)} + } + }() + defer func() { + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("response body close error: %v", errClose) + } + }() + + // Kiro API always returns tags regardless of request parameters + // So we always enable thinking parsing for Kiro responses + log.Debugf("kiro: stream thinkingEnabled = %v (always true for Kiro)", thinkingEnabled) + + e.streamToChannel(ctx, resp.Body, out, from, req.Model, opts.OriginalRequest, body, reporter, thinkingEnabled) + }(httpResp, thinkingEnabled) + + return out, nil + } + // Inner retry loop exhausted for this endpoint, try next endpoint + // Note: This code is unreachable because all paths in the inner loop + // either return or continue. Kept as comment for documentation. + } + + // All endpoints exhausted + if last429Err != nil { + return nil, last429Err + } + return nil, fmt.Errorf("kiro: stream all endpoints exhausted") +} + +// kiroCredentials extracts access token and profile ARN from auth. +func kiroCredentials(auth *cliproxyauth.Auth) (accessToken, profileArn string) { + if auth == nil { + return "", "" + } + + // Try Metadata first (wrapper format) + if auth.Metadata != nil { + if token, ok := auth.Metadata["access_token"].(string); ok { + accessToken = token + } + if arn, ok := auth.Metadata["profile_arn"].(string); ok { + profileArn = arn + } + } + + // Try Attributes + if accessToken == "" && auth.Attributes != nil { + accessToken = auth.Attributes["access_token"] + profileArn = auth.Attributes["profile_arn"] + } + + // Try direct fields from flat JSON format (new AWS Builder ID format) + if accessToken == "" && auth.Metadata != nil { + if token, ok := auth.Metadata["accessToken"].(string); ok { + accessToken = token + } + if arn, ok := auth.Metadata["profileArn"].(string); ok { + profileArn = arn + } + } + + return accessToken, profileArn +} + +// findRealThinkingEndTag finds the real end tag, skipping false positives. +// Returns -1 if no real end tag is found. +// +// Real tags from Kiro API have specific characteristics: +// - Usually preceded by newline (.\n) +// - Usually followed by newline (\n\n) +// - Not inside code blocks or inline code +// +// False positives (discussion text) have characteristics: +// - In the middle of a sentence +// - Preceded by discussion words like "标签", "tag", "returns" +// - Inside code blocks or inline code +// +// Parameters: +// - content: the content to search in +// - alreadyInCodeBlock: whether we're already inside a code block from previous chunks +// - alreadyInInlineCode: whether we're already inside inline code from previous chunks +func findRealThinkingEndTag(content string, alreadyInCodeBlock, alreadyInInlineCode bool) int { + searchStart := 0 + for { + endIdx := strings.Index(content[searchStart:], kirocommon.ThinkingEndTag) + if endIdx < 0 { + return -1 + } + endIdx += searchStart // Adjust to absolute position + + textBeforeEnd := content[:endIdx] + textAfterEnd := content[endIdx+len(kirocommon.ThinkingEndTag):] + + // Check 1: Is it inside inline code? + // Count backticks in current content and add state from previous chunks + backtickCount := strings.Count(textBeforeEnd, "`") + effectiveInInlineCode := alreadyInInlineCode + if backtickCount%2 == 1 { + effectiveInInlineCode = !effectiveInInlineCode + } + if effectiveInInlineCode { + log.Debugf("kiro: found inside inline code at pos %d, skipping", endIdx) + searchStart = endIdx + len(kirocommon.ThinkingEndTag) + continue + } + + // Check 2: Is it inside a code block? + // Count fences in current content and add state from previous chunks + fenceCount := strings.Count(textBeforeEnd, "```") + altFenceCount := strings.Count(textBeforeEnd, "~~~") + effectiveInCodeBlock := alreadyInCodeBlock + if fenceCount%2 == 1 || altFenceCount%2 == 1 { + effectiveInCodeBlock = !effectiveInCodeBlock + } + if effectiveInCodeBlock { + log.Debugf("kiro: found inside code block at pos %d, skipping", endIdx) + searchStart = endIdx + len(kirocommon.ThinkingEndTag) + continue + } + + // Check 3: Real tags are usually preceded by newline or at start + // and followed by newline or at end. Check the format. + charBeforeTag := byte(0) + if endIdx > 0 { + charBeforeTag = content[endIdx-1] + } + charAfterTag := byte(0) + if len(textAfterEnd) > 0 { + charAfterTag = textAfterEnd[0] + } + + // Real end tag format: preceded by newline OR end of sentence (. ! ?) + // and followed by newline OR end of content + isPrecededByNewlineOrSentenceEnd := charBeforeTag == '\n' || charBeforeTag == '.' || + charBeforeTag == '!' || charBeforeTag == '?' || charBeforeTag == 0 + isFollowedByNewlineOrEnd := charAfterTag == '\n' || charAfterTag == 0 + + // If the tag has proper formatting (newline before/after), it's likely real + if isPrecededByNewlineOrSentenceEnd && isFollowedByNewlineOrEnd { + log.Debugf("kiro: found properly formatted at pos %d", endIdx) + return endIdx + } + + // Check 4: Is the tag preceded by discussion keywords on the same line? + lastNewlineIdx := strings.LastIndex(textBeforeEnd, "\n") + lineBeforeTag := textBeforeEnd + if lastNewlineIdx >= 0 { + lineBeforeTag = textBeforeEnd[lastNewlineIdx+1:] + } + lineBeforeTagLower := strings.ToLower(lineBeforeTag) + + // Discussion patterns - if found, this is likely discussion text + discussionPatterns := []string{ + "标签", "返回", "输出", "包含", "使用", "解析", "转换", "生成", // Chinese + "tag", "return", "output", "contain", "use", "parse", "emit", "convert", "generate", // English + "", // discussing both tags together + "``", // explicitly in inline code + } + isDiscussion := false + for _, pattern := range discussionPatterns { + if strings.Contains(lineBeforeTagLower, pattern) { + isDiscussion = true + break + } + } + if isDiscussion { + log.Debugf("kiro: found after discussion text at pos %d, skipping", endIdx) + searchStart = endIdx + len(kirocommon.ThinkingEndTag) + continue + } + + // Check 5: Is there text immediately after on the same line? + // Real end tags don't have text immediately after on the same line + if len(textAfterEnd) > 0 && charAfterTag != '\n' && charAfterTag != 0 { + // Find the next newline + nextNewline := strings.Index(textAfterEnd, "\n") + var textOnSameLine string + if nextNewline >= 0 { + textOnSameLine = textAfterEnd[:nextNewline] + } else { + textOnSameLine = textAfterEnd + } + // If there's non-whitespace text on the same line after the tag, it's discussion + if strings.TrimSpace(textOnSameLine) != "" { + log.Debugf("kiro: found with text after on same line at pos %d, skipping", endIdx) + searchStart = endIdx + len(kirocommon.ThinkingEndTag) + continue + } + } + + // Check 6: Is there another tag after this ? + if strings.Contains(textAfterEnd, kirocommon.ThinkingStartTag) { + nextStartIdx := strings.Index(textAfterEnd, kirocommon.ThinkingStartTag) + textBeforeNextStart := textAfterEnd[:nextStartIdx] + nextBacktickCount := strings.Count(textBeforeNextStart, "`") + nextFenceCount := strings.Count(textBeforeNextStart, "```") + nextAltFenceCount := strings.Count(textBeforeNextStart, "~~~") + + // If the next is NOT in code, then this is discussion text + if nextBacktickCount%2 == 0 && nextFenceCount%2 == 0 && nextAltFenceCount%2 == 0 { + log.Debugf("kiro: found followed by at pos %d, likely discussion text, skipping", endIdx) + searchStart = endIdx + len(kirocommon.ThinkingEndTag) + continue + } + } + + // This looks like a real end tag + return endIdx + } +} + +// determineAgenticMode determines if the model is an agentic or chat-only variant. +// Returns (isAgentic, isChatOnly) based on model name suffixes. +func determineAgenticMode(model string) (isAgentic, isChatOnly bool) { + isAgentic = strings.HasSuffix(model, "-agentic") + isChatOnly = strings.HasSuffix(model, "-chat") + return isAgentic, isChatOnly +} + +// getEffectiveProfileArn determines if profileArn should be included based on auth method. +// profileArn is only needed for social auth (Google OAuth), not for builder-id (AWS SSO). +func getEffectiveProfileArn(auth *cliproxyauth.Auth, profileArn string) string { + if auth != nil && auth.Metadata != nil { + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { + return "" // Don't include profileArn for builder-id auth + } + } + return profileArn +} + +// getEffectiveProfileArnWithWarning determines if profileArn should be included based on auth method, +// and logs a warning if profileArn is missing for non-builder-id auth. +// This consolidates the auth_method check that was previously done separately. +func getEffectiveProfileArnWithWarning(auth *cliproxyauth.Auth, profileArn string) string { + if auth != nil && auth.Metadata != nil { + if authMethod, ok := auth.Metadata["auth_method"].(string); ok && authMethod == "builder-id" { + // builder-id auth doesn't need profileArn + return "" + } + } + // For non-builder-id auth (social auth), profileArn is required + if profileArn == "" { + log.Warnf("kiro: profile ARN not found in auth, API calls may fail") + } + return profileArn +} + +// mapModelToKiro maps external model names to Kiro model IDs. +// Supports both Kiro and Amazon Q prefixes since they use the same API. +// Agentic variants (-agentic suffix) map to the same backend model IDs. +func (e *KiroExecutor) mapModelToKiro(model string) string { + modelMap := map[string]string{ + // Amazon Q format (amazonq- prefix) - same API as Kiro + "amazonq-auto": "auto", + "amazonq-claude-opus-4-5": "claude-opus-4.5", + "amazonq-claude-sonnet-4-5": "claude-sonnet-4.5", + "amazonq-claude-sonnet-4-5-20250929": "claude-sonnet-4.5", + "amazonq-claude-sonnet-4": "claude-sonnet-4", + "amazonq-claude-sonnet-4-20250514": "claude-sonnet-4", + "amazonq-claude-haiku-4-5": "claude-haiku-4.5", + // Kiro format (kiro- prefix) - valid model names that should be preserved + "kiro-claude-opus-4-5": "claude-opus-4.5", + "kiro-claude-sonnet-4-5": "claude-sonnet-4.5", + "kiro-claude-sonnet-4-5-20250929": "claude-sonnet-4.5", + "kiro-claude-sonnet-4": "claude-sonnet-4", + "kiro-claude-sonnet-4-20250514": "claude-sonnet-4", + "kiro-claude-haiku-4-5": "claude-haiku-4.5", + "kiro-auto": "auto", + // Native format (no prefix) - used by Kiro IDE directly + "claude-opus-4-5": "claude-opus-4.5", + "claude-opus-4.5": "claude-opus-4.5", + "claude-haiku-4-5": "claude-haiku-4.5", + "claude-haiku-4.5": "claude-haiku-4.5", + "claude-sonnet-4-5": "claude-sonnet-4.5", + "claude-sonnet-4-5-20250929": "claude-sonnet-4.5", + "claude-sonnet-4.5": "claude-sonnet-4.5", + "claude-sonnet-4": "claude-sonnet-4", + "claude-sonnet-4-20250514": "claude-sonnet-4", + "auto": "auto", + // Agentic variants (same backend model IDs, but with special system prompt) + "claude-opus-4.5-agentic": "claude-opus-4.5", + "claude-sonnet-4.5-agentic": "claude-sonnet-4.5", + "claude-sonnet-4-agentic": "claude-sonnet-4", + "claude-haiku-4.5-agentic": "claude-haiku-4.5", + "kiro-claude-opus-4-5-agentic": "claude-opus-4.5", + "kiro-claude-sonnet-4-5-agentic": "claude-sonnet-4.5", + "kiro-claude-sonnet-4-agentic": "claude-sonnet-4", + "kiro-claude-haiku-4-5-agentic": "claude-haiku-4.5", + } + if kiroID, ok := modelMap[model]; ok { + return kiroID + } + + // Smart fallback: try to infer model type from name patterns + modelLower := strings.ToLower(model) + + // Check for Haiku variants + if strings.Contains(modelLower, "haiku") { + log.Debugf("kiro: unknown Haiku model '%s', mapping to claude-haiku-4.5", model) + return "claude-haiku-4.5" + } + + // Check for Sonnet variants + if strings.Contains(modelLower, "sonnet") { + // Check for specific version patterns + if strings.Contains(modelLower, "3-7") || strings.Contains(modelLower, "3.7") { + log.Debugf("kiro: unknown Sonnet 3.7 model '%s', mapping to claude-3-7-sonnet-20250219", model) + return "claude-3-7-sonnet-20250219" + } + if strings.Contains(modelLower, "4-5") || strings.Contains(modelLower, "4.5") { + log.Debugf("kiro: unknown Sonnet 4.5 model '%s', mapping to claude-sonnet-4.5", model) + return "claude-sonnet-4.5" + } + // Default to Sonnet 4 + log.Debugf("kiro: unknown Sonnet model '%s', mapping to claude-sonnet-4", model) + return "claude-sonnet-4" + } + + // Check for Opus variants + if strings.Contains(modelLower, "opus") { + log.Debugf("kiro: unknown Opus model '%s', mapping to claude-opus-4.5", model) + return "claude-opus-4.5" + } + + // Final fallback to Sonnet 4.5 (most commonly used model) + log.Warnf("kiro: unknown model '%s', falling back to claude-sonnet-4.5", model) + return "claude-sonnet-4.5" +} + +// EventStreamError represents an Event Stream processing error +type EventStreamError struct { + Type string // "fatal", "malformed" + Message string + Cause error +} + +func (e *EventStreamError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("event stream %s: %s: %v", e.Type, e.Message, e.Cause) + } + return fmt.Sprintf("event stream %s: %s", e.Type, e.Message) +} + +// eventStreamMessage represents a parsed AWS Event Stream message +type eventStreamMessage struct { + EventType string // Event type from headers (e.g., "assistantResponseEvent") + Payload []byte // JSON payload of the message +} + +// NOTE: Request building functions moved to internal/translator/kiro/claude/kiro_claude_request.go +// The executor now uses kiroclaude.BuildKiroPayload() instead + +// parseEventStream parses AWS Event Stream binary format. +// Extracts text content, tool uses, and stop_reason from the response. +// Supports embedded [Called ...] tool calls and input buffering for toolUseEvent. +// Returns: content, toolUses, usageInfo, stopReason, error +func (e *KiroExecutor) parseEventStream(body io.Reader) (string, []kiroclaude.KiroToolUse, usage.Detail, string, error) { + var content strings.Builder + var toolUses []kiroclaude.KiroToolUse + var usageInfo usage.Detail + var stopReason string // Extracted from upstream response + reader := bufio.NewReader(body) + + // Tool use state tracking for input buffering and deduplication + processedIDs := make(map[string]bool) + var currentToolUse *kiroclaude.ToolUseState + + // Upstream usage tracking - Kiro API returns credit usage and context percentage + var upstreamContextPercentage float64 // Context usage percentage from upstream (e.g., 78.56) + + for { + msg, eventErr := e.readEventStreamMessage(reader) + if eventErr != nil { + log.Errorf("kiro: parseEventStream error: %v", eventErr) + return content.String(), toolUses, usageInfo, stopReason, eventErr + } + if msg == nil { + // Normal end of stream (EOF) + break + } + + eventType := msg.EventType + payload := msg.Payload + if len(payload) == 0 { + continue + } + + var event map[string]interface{} + if err := json.Unmarshal(payload, &event); err != nil { + log.Debugf("kiro: skipping malformed event: %v", err) + continue + } + + // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200) + // These can appear as top-level fields or nested within the event + if errType, hasErrType := event["_type"].(string); hasErrType { + // AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."} + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } + log.Errorf("kiro: received AWS error in event stream: type=%s, message=%s", errType, errMsg) + return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s - %s", errType, errMsg) + } + if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") { + // Generic error event + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } else if errObj, ok := event["error"].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + errMsg = msg + } + } + log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) + return "", nil, usageInfo, stopReason, fmt.Errorf("kiro API error: %s", errMsg) + } + + // Extract stop_reason from various event formats + // Kiro/Amazon Q API may include stop_reason in different locations + if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stop_reason (top-level): %s", stopReason) + } + if sr := kirocommon.GetString(event, "stopReason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stopReason (top-level): %s", stopReason) + } + + // Handle different event types + switch eventType { + case "followupPromptEvent": + // Filter out followupPrompt events - these are UI suggestions, not content + log.Debugf("kiro: parseEventStream ignoring followupPrompt event") + continue + + case "assistantResponseEvent": + if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { + if contentText, ok := assistantResp["content"].(string); ok { + content.WriteString(contentText) + } + // Extract stop_reason from assistantResponseEvent + if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stop_reason in assistantResponseEvent: %s", stopReason) + } + if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stopReason in assistantResponseEvent: %s", stopReason) + } + // Extract tool uses from response + if toolUsesRaw, ok := assistantResp["toolUses"].([]interface{}); ok { + for _, tuRaw := range toolUsesRaw { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUseID := kirocommon.GetStringValue(tu, "toolUseId") + // Check for duplicate + if processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate tool use from assistantResponse: %s", toolUseID) + continue + } + processedIDs[toolUseID] = true + + toolUse := kiroclaude.KiroToolUse{ + ToolUseID: toolUseID, + Name: kirocommon.GetStringValue(tu, "name"), + } + if input, ok := tu["input"].(map[string]interface{}); ok { + toolUse.Input = input + } + toolUses = append(toolUses, toolUse) + } + } + } + } + // Also try direct format + if contentText, ok := event["content"].(string); ok { + content.WriteString(contentText) + } + // Direct tool uses + if toolUsesRaw, ok := event["toolUses"].([]interface{}); ok { + for _, tuRaw := range toolUsesRaw { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUseID := kirocommon.GetStringValue(tu, "toolUseId") + // Check for duplicate + if processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate direct tool use: %s", toolUseID) + continue + } + processedIDs[toolUseID] = true + + toolUse := kiroclaude.KiroToolUse{ + ToolUseID: toolUseID, + Name: kirocommon.GetStringValue(tu, "name"), + } + if input, ok := tu["input"].(map[string]interface{}); ok { + toolUse.Input = input + } + toolUses = append(toolUses, toolUse) + } + } + } + + case "toolUseEvent": + // Handle dedicated tool use events with input buffering + completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs) + currentToolUse = newState + toolUses = append(toolUses, completedToolUses...) + + case "supplementaryWebLinksEvent": + if inputTokens, ok := event["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } + + case "messageStopEvent", "message_stop": + // Handle message stop events which may contain stop_reason + if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stop_reason in messageStopEvent: %s", stopReason) + } + if sr := kirocommon.GetString(event, "stopReason"); sr != "" { + stopReason = sr + log.Debugf("kiro: parseEventStream found stopReason in messageStopEvent: %s", stopReason) + } + + case "messageMetadataEvent": + // Handle message metadata events which may contain token counts + if metadata, ok := event["messageMetadataEvent"].(map[string]interface{}); ok { + if inputTokens, ok := metadata["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + log.Debugf("kiro: parseEventStream found inputTokens in messageMetadataEvent: %d", usageInfo.InputTokens) + } + if outputTokens, ok := metadata["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + log.Debugf("kiro: parseEventStream found outputTokens in messageMetadataEvent: %d", usageInfo.OutputTokens) + } + if totalTokens, ok := metadata["totalTokens"].(float64); ok { + usageInfo.TotalTokens = int64(totalTokens) + log.Debugf("kiro: parseEventStream found totalTokens in messageMetadataEvent: %d", usageInfo.TotalTokens) + } + } + + case "usageEvent", "usage": + // Handle dedicated usage events + if inputTokens, ok := event["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + log.Debugf("kiro: parseEventStream found inputTokens in usageEvent: %d", usageInfo.InputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + log.Debugf("kiro: parseEventStream found outputTokens in usageEvent: %d", usageInfo.OutputTokens) + } + if totalTokens, ok := event["totalTokens"].(float64); ok { + usageInfo.TotalTokens = int64(totalTokens) + log.Debugf("kiro: parseEventStream found totalTokens in usageEvent: %d", usageInfo.TotalTokens) + } + // Also check nested usage object + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + usageInfo.TotalTokens = int64(totalTokens) + } + log.Debugf("kiro: parseEventStream found usage object: input=%d, output=%d, total=%d", + usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens) + } + + case "metricsEvent": + // Handle metrics events which may contain usage data + if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok { + if inputTokens, ok := metrics["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } + if outputTokens, ok := metrics["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } + log.Debugf("kiro: parseEventStream found metricsEvent: input=%d, output=%d", + usageInfo.InputTokens, usageInfo.OutputTokens) + } + + default: + // Check for contextUsagePercentage in any event + if ctxPct, ok := event["contextUsagePercentage"].(float64); ok { + upstreamContextPercentage = ctxPct + log.Debugf("kiro: parseEventStream received context usage: %.2f%%", upstreamContextPercentage) + } + // Log unknown event types for debugging (to discover new event formats) + log.Debugf("kiro: parseEventStream unknown event type: %s, payload: %s", eventType, string(payload)) + } + + // Check for direct token fields in any event (fallback) + if usageInfo.InputTokens == 0 { + if inputTokens, ok := event["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + log.Debugf("kiro: parseEventStream found direct inputTokens: %d", usageInfo.InputTokens) + } + } + if usageInfo.OutputTokens == 0 { + if outputTokens, ok := event["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + log.Debugf("kiro: parseEventStream found direct outputTokens: %d", usageInfo.OutputTokens) + } + } + + // Check for usage object in any event (OpenAI format) + if usageInfo.InputTokens == 0 || usageInfo.OutputTokens == 0 { + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if usageInfo.InputTokens == 0 { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } + } + if usageInfo.OutputTokens == 0 { + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } + } + if usageInfo.TotalTokens == 0 { + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + usageInfo.TotalTokens = int64(totalTokens) + } + } + log.Debugf("kiro: parseEventStream found usage object (fallback): input=%d, output=%d, total=%d", + usageInfo.InputTokens, usageInfo.OutputTokens, usageInfo.TotalTokens) + } + } + + // Also check nested supplementaryWebLinksEvent + if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { + if inputTokens, ok := usageEvent["inputTokens"].(float64); ok { + usageInfo.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { + usageInfo.OutputTokens = int64(outputTokens) + } + } + } + + // Parse embedded tool calls from content (e.g., [Called tool_name with args: {...}]) + contentStr := content.String() + cleanedContent, embeddedToolUses := kiroclaude.ParseEmbeddedToolCalls(contentStr, processedIDs) + toolUses = append(toolUses, embeddedToolUses...) + + // Deduplicate all tool uses + toolUses = kiroclaude.DeduplicateToolUses(toolUses) + + // Apply fallback logic for stop_reason if not provided by upstream + // Priority: upstream stopReason > tool_use detection > end_turn default + if stopReason == "" { + if len(toolUses) > 0 { + stopReason = "tool_use" + log.Debugf("kiro: parseEventStream using fallback stop_reason: tool_use (detected %d tool uses)", len(toolUses)) + } else { + stopReason = "end_turn" + log.Debugf("kiro: parseEventStream using fallback stop_reason: end_turn") + } + } + + // Log warning if response was truncated due to max_tokens + if stopReason == "max_tokens" { + log.Warnf("kiro: response truncated due to max_tokens limit") + } + + // Use contextUsagePercentage to calculate more accurate input tokens + // Kiro model has 200k max context, contextUsagePercentage represents the percentage used + // Formula: input_tokens = contextUsagePercentage * 200000 / 100 + if upstreamContextPercentage > 0 { + calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100) + if calculatedInputTokens > 0 { + localEstimate := usageInfo.InputTokens + usageInfo.InputTokens = calculatedInputTokens + usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens + log.Infof("kiro: parseEventStream using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)", + upstreamContextPercentage, calculatedInputTokens, localEstimate) + } + } + + return cleanedContent, toolUses, usageInfo, stopReason, nil +} + +// readEventStreamMessage reads and validates a single AWS Event Stream message. +// Returns the parsed message or a structured error for different failure modes. +// This function implements boundary protection and detailed error classification. +// +// AWS Event Stream binary format: +// - Prelude (12 bytes): total_length (4) + headers_length (4) + prelude_crc (4) +// - Headers (variable): header entries +// - Payload (variable): JSON data +// - Message CRC (4 bytes): CRC32C of entire message (not validated, just skipped) +func (e *KiroExecutor) readEventStreamMessage(reader *bufio.Reader) (*eventStreamMessage, *EventStreamError) { + // Read prelude (first 12 bytes: total_len + headers_len + prelude_crc) + prelude := make([]byte, 12) + _, err := io.ReadFull(reader, prelude) + if err == io.EOF { + return nil, nil // Normal end of stream + } + if err != nil { + return nil, &EventStreamError{ + Type: ErrStreamFatal, + Message: "failed to read prelude", + Cause: err, + } + } + + totalLength := binary.BigEndian.Uint32(prelude[0:4]) + headersLength := binary.BigEndian.Uint32(prelude[4:8]) + // Note: prelude[8:12] is prelude_crc - we read it but don't validate (no CRC check per requirements) + + // Boundary check: minimum frame size + if totalLength < minEventStreamFrameSize { + return nil, &EventStreamError{ + Type: ErrStreamMalformed, + Message: fmt.Sprintf("invalid message length: %d (minimum is %d)", totalLength, minEventStreamFrameSize), + } + } + + // Boundary check: maximum message size + if totalLength > maxEventStreamMsgSize { + return nil, &EventStreamError{ + Type: ErrStreamMalformed, + Message: fmt.Sprintf("message too large: %d bytes (maximum is %d)", totalLength, maxEventStreamMsgSize), + } + } + + // Boundary check: headers length within message bounds + // Message structure: prelude(12) + headers(headersLength) + payload + message_crc(4) + // So: headersLength must be <= totalLength - 16 (12 for prelude + 4 for message_crc) + if headersLength > totalLength-16 { + return nil, &EventStreamError{ + Type: ErrStreamMalformed, + Message: fmt.Sprintf("headers length %d exceeds message bounds (total: %d)", headersLength, totalLength), + } + } + + // Read the rest of the message (total - 12 bytes already read) + remaining := make([]byte, totalLength-12) + _, err = io.ReadFull(reader, remaining) + if err != nil { + return nil, &EventStreamError{ + Type: ErrStreamFatal, + Message: "failed to read message body", + Cause: err, + } + } + + // Extract event type from headers + // Headers start at beginning of 'remaining', length is headersLength + var eventType string + if headersLength > 0 && headersLength <= uint32(len(remaining)) { + eventType = e.extractEventTypeFromBytes(remaining[:headersLength]) + } + + // Calculate payload boundaries + // Payload starts after headers, ends before message_crc (last 4 bytes) + payloadStart := headersLength + payloadEnd := uint32(len(remaining)) - 4 // Skip message_crc at end + + // Validate payload boundaries + if payloadStart >= payloadEnd { + // No payload, return empty message + return &eventStreamMessage{ + EventType: eventType, + Payload: nil, + }, nil + } + + payload := remaining[payloadStart:payloadEnd] + + return &eventStreamMessage{ + EventType: eventType, + Payload: payload, + }, nil +} + +func skipEventStreamHeaderValue(headers []byte, offset int, valueType byte) (int, bool) { + switch valueType { + case 0, 1: // bool true / bool false + return offset, true + case 2: // byte + if offset+1 > len(headers) { + return offset, false + } + return offset + 1, true + case 3: // short + if offset+2 > len(headers) { + return offset, false + } + return offset + 2, true + case 4: // int + if offset+4 > len(headers) { + return offset, false + } + return offset + 4, true + case 5: // long + if offset+8 > len(headers) { + return offset, false + } + return offset + 8, true + case 6: // byte array (2-byte length + data) + if offset+2 > len(headers) { + return offset, false + } + valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) + offset += 2 + if offset+valueLen > len(headers) { + return offset, false + } + return offset + valueLen, true + case 8: // timestamp + if offset+8 > len(headers) { + return offset, false + } + return offset + 8, true + case 9: // uuid + if offset+16 > len(headers) { + return offset, false + } + return offset + 16, true + default: + return offset, false + } +} + +// extractEventTypeFromBytes extracts the event type from raw header bytes (without prelude CRC prefix) +func (e *KiroExecutor) extractEventTypeFromBytes(headers []byte) string { + offset := 0 + for offset < len(headers) { + nameLen := int(headers[offset]) + offset++ + if offset+nameLen > len(headers) { + break + } + name := string(headers[offset : offset+nameLen]) + offset += nameLen + + if offset >= len(headers) { + break + } + valueType := headers[offset] + offset++ + + if valueType == 7 { // String type + if offset+2 > len(headers) { + break + } + valueLen := int(binary.BigEndian.Uint16(headers[offset : offset+2])) + offset += 2 + if offset+valueLen > len(headers) { + break + } + value := string(headers[offset : offset+valueLen]) + offset += valueLen + + if name == ":event-type" { + return value + } + continue + } + + nextOffset, ok := skipEventStreamHeaderValue(headers, offset, valueType) + if !ok { + break + } + offset = nextOffset + } + return "" +} + + +// NOTE: Response building functions moved to internal/translator/kiro/claude/kiro_claude_response.go +// The executor now uses kiroclaude.BuildClaudeResponse() and kiroclaude.ExtractThinkingFromContent() instead + +// streamToChannel converts AWS Event Stream to channel-based streaming. +// Supports tool calling - emits tool_use content blocks when tools are used. +// Includes embedded [Called ...] tool call parsing and input buffering for toolUseEvent. +// Implements duplicate content filtering using lastContentEvent detection (based on AIClient-2-API). +// Extracts stop_reason from upstream events when available. +// thinkingEnabled controls whether tags are parsed - only parse when request enabled thinking. +func (e *KiroExecutor) streamToChannel(ctx context.Context, body io.Reader, out chan<- cliproxyexecutor.StreamChunk, targetFormat sdktranslator.Format, model string, originalReq, claudeBody []byte, reporter *usageReporter, thinkingEnabled bool) { + reader := bufio.NewReaderSize(body, 20*1024*1024) // 20MB buffer to match other providers + var totalUsage usage.Detail + var hasToolUses bool // Track if any tool uses were emitted + var upstreamStopReason string // Track stop_reason from upstream events + + // Tool use state tracking for input buffering and deduplication + processedIDs := make(map[string]bool) + var currentToolUse *kiroclaude.ToolUseState + + // NOTE: Duplicate content filtering removed - it was causing legitimate repeated + // content (like consecutive newlines) to be incorrectly filtered out. + // The previous implementation compared lastContentEvent == contentDelta which + // is too aggressive for streaming scenarios. + + // Streaming token calculation - accumulate content for real-time token counting + // Based on AIClient-2-API implementation + var accumulatedContent strings.Builder + accumulatedContent.Grow(4096) // Pre-allocate 4KB capacity to reduce reallocations + + // Real-time usage estimation state + // These track when to send periodic usage updates during streaming + var lastUsageUpdateLen int // Last accumulated content length when usage was sent + var lastUsageUpdateTime = time.Now() // Last time usage update was sent + var lastReportedOutputTokens int64 // Last reported output token count + + // Upstream usage tracking - Kiro API returns credit usage and context percentage + var upstreamCreditUsage float64 // Credit usage from upstream (e.g., 1.458) + var upstreamContextPercentage float64 // Context usage percentage from upstream (e.g., 78.56) + var hasUpstreamUsage bool // Whether we received usage from upstream + + // Translator param for maintaining tool call state across streaming events + // IMPORTANT: This must persist across all TranslateStream calls + var translatorParam any + + // Thinking mode state tracking - based on amq2api implementation + // Tracks whether we're inside a block and handles partial tags + inThinkBlock := false + pendingStartTagChars := 0 // Number of chars that might be start of + pendingEndTagChars := 0 // Number of chars that might be start of + isThinkingBlockOpen := false // Track if thinking content block is open + thinkingBlockIndex := -1 // Index of the thinking content block + + // Code block state tracking for heuristic thinking tag parsing + // When inside a markdown code block, tags should NOT be parsed + // This prevents false positives when the model outputs code examples containing these tags + inCodeBlock := false + codeFenceType := "" // Track which fence type opened the block ("```" or "~~~") + + // Inline code state tracking - when inside backticks, don't parse thinking tags + // This handles cases like `` being discussed in text + inInlineCode := false + + // Track if we've seen any non-whitespace content before a thinking tag + // Real thinking blocks from Kiro always start at the very beginning of the response + // If we see content before , subsequent tags are likely discussion text + hasSeenNonThinkingContent := false + thinkingBlockCompleted := false // Track if we've already completed a thinking block + + // Pre-calculate input tokens from request if possible + // Kiro uses Claude format, so try Claude format first, then OpenAI format, then fallback + if enc, err := getTokenizer(model); err == nil { + var inputTokens int64 + var countMethod string + + // Try Claude format first (Kiro uses Claude API format) + if inp, err := countClaudeChatTokens(enc, claudeBody); err == nil && inp > 0 { + inputTokens = inp + countMethod = "claude" + } else if inp, err := countOpenAIChatTokens(enc, originalReq); err == nil && inp > 0 { + // Fallback to OpenAI format (for OpenAI-compatible requests) + inputTokens = inp + countMethod = "openai" + } else { + // Final fallback: estimate from raw request size (roughly 4 chars per token) + inputTokens = int64(len(claudeBody) / 4) + if inputTokens == 0 && len(claudeBody) > 0 { + inputTokens = 1 + } + countMethod = "estimate" + } + + totalUsage.InputTokens = inputTokens + log.Debugf("kiro: streamToChannel pre-calculated input tokens: %d (method: %s, claude body: %d bytes, original req: %d bytes)", + totalUsage.InputTokens, countMethod, len(claudeBody), len(originalReq)) + } + + contentBlockIndex := -1 + messageStartSent := false + isTextBlockOpen := false + var outputLen int + + // Ensure usage is published even on early return + defer func() { + reporter.publish(ctx, totalUsage) + }() + + for { + select { + case <-ctx.Done(): + return + default: + } + + msg, eventErr := e.readEventStreamMessage(reader) + if eventErr != nil { + // Log the error + log.Errorf("kiro: streamToChannel error: %v", eventErr) + + // Send error to channel for client notification + out <- cliproxyexecutor.StreamChunk{Err: eventErr} + return + } + if msg == nil { + // Normal end of stream (EOF) + // Flush any incomplete tool use before ending stream + if currentToolUse != nil && !processedIDs[currentToolUse.ToolUseID] { + log.Warnf("kiro: flushing incomplete tool use at EOF: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID) + fullInput := currentToolUse.InputBuffer.String() + repairedJSON := kiroclaude.RepairJSON(fullInput) + var finalInput map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { + log.Warnf("kiro: failed to parse incomplete tool input at EOF: %v", err) + finalInput = make(map[string]interface{}) + } + + processedIDs[currentToolUse.ToolUseID] = true + contentBlockIndex++ + + // Send tool_use content block + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", currentToolUse.ToolUseID, currentToolUse.Name) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send tool input as delta + inputBytes, _ := json.Marshal(finalInput) + inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputBytes), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Close block + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + hasToolUses = true + currentToolUse = nil + } + + // Flush any pending tag characters at EOF + // These are partial tag prefixes that were held back waiting for more data + // Since no more data is coming, output them as regular text + var pendingText string + if pendingStartTagChars > 0 { + pendingText = kirocommon.ThinkingStartTag[:pendingStartTagChars] + log.Debugf("kiro: flushing pending start tag chars at EOF: %q", pendingText) + pendingStartTagChars = 0 + } + if pendingEndTagChars > 0 { + pendingText += kirocommon.ThinkingEndTag[:pendingEndTagChars] + log.Debugf("kiro: flushing pending end tag chars at EOF: %q", pendingText) + pendingEndTagChars = 0 + } + + // Output pending text if any + if pendingText != "" { + // If we're in a thinking block, output as thinking content + if inThinkBlock && isThinkingBlockOpen { + thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(pendingText, thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } else { + // Output as regular text + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := kiroclaude.BuildClaudeStreamEvent(pendingText, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } + break + } + + eventType := msg.EventType + payload := msg.Payload + if len(payload) == 0 { + continue + } + appendAPIResponseChunk(ctx, e.cfg, payload) + + var event map[string]interface{} + if err := json.Unmarshal(payload, &event); err != nil { + log.Warnf("kiro: failed to unmarshal event payload: %v, raw: %s", err, string(payload)) + continue + } + + // Check for error/exception events in the payload (Kiro API may return errors with HTTP 200) + // These can appear as top-level fields or nested within the event + if errType, hasErrType := event["_type"].(string); hasErrType { + // AWS-style error: {"_type": "com.amazon.aws.codewhisperer#ValidationException", "message": "..."} + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } + log.Errorf("kiro: received AWS error in stream: type=%s, message=%s", errType, errMsg) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s - %s", errType, errMsg)} + return + } + if errType, hasErrType := event["type"].(string); hasErrType && (errType == "error" || errType == "exception") { + // Generic error event + errMsg := "" + if msg, ok := event["message"].(string); ok { + errMsg = msg + } else if errObj, ok := event["error"].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + errMsg = msg + } + } + log.Errorf("kiro: received error event in stream: type=%s, message=%s", errType, errMsg) + out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("kiro API error: %s", errMsg)} + return + } + + // Extract stop_reason from various event formats (streaming) + // Kiro/Amazon Q API may include stop_reason in different locations + if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stop_reason (top-level): %s", upstreamStopReason) + } + if sr := kirocommon.GetString(event, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason (top-level): %s", upstreamStopReason) + } + + // Send message_start on first event + if !messageStartSent { + msgStart := kiroclaude.BuildClaudeMessageStartEvent(model, totalUsage.InputTokens) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + messageStartSent = true + } + + switch eventType { + case "followupPromptEvent": + // Filter out followupPrompt events - these are UI suggestions, not content + log.Debugf("kiro: streamToChannel ignoring followupPrompt event") + continue + + case "messageStopEvent", "message_stop": + // Handle message stop events which may contain stop_reason + if sr := kirocommon.GetString(event, "stop_reason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stop_reason in messageStopEvent: %s", upstreamStopReason) + } + if sr := kirocommon.GetString(event, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason in messageStopEvent: %s", upstreamStopReason) + } + + default: + // Check for upstream usage events from Kiro API + // Format: {"unit":"credit","unitPlural":"credits","usage":1.458} + if unit, ok := event["unit"].(string); ok && unit == "credit" { + if usage, ok := event["usage"].(float64); ok { + upstreamCreditUsage = usage + hasUpstreamUsage = true + log.Debugf("kiro: received upstream credit usage: %.4f", upstreamCreditUsage) + } + } + // Format: {"contextUsagePercentage":78.56} + if ctxPct, ok := event["contextUsagePercentage"].(float64); ok { + upstreamContextPercentage = ctxPct + log.Debugf("kiro: received upstream context usage: %.2f%%", upstreamContextPercentage) + } + + // Check for token counts in unknown events + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + hasUpstreamUsage = true + log.Debugf("kiro: streamToChannel found inputTokens in event %s: %d", eventType, totalUsage.InputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + hasUpstreamUsage = true + log.Debugf("kiro: streamToChannel found outputTokens in event %s: %d", eventType, totalUsage.OutputTokens) + } + if totalTokens, ok := event["totalTokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + log.Debugf("kiro: streamToChannel found totalTokens in event %s: %d", eventType, totalUsage.TotalTokens) + } + + // Check for usage object in unknown events (OpenAI/Claude format) + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + hasUpstreamUsage = true + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + hasUpstreamUsage = true + } + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + hasUpstreamUsage = true + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + hasUpstreamUsage = true + } + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + } + log.Debugf("kiro: streamToChannel found usage object in event %s: input=%d, output=%d, total=%d", + eventType, totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + + // Log unknown event types for debugging (to discover new event formats) + if eventType != "" { + log.Debugf("kiro: streamToChannel unknown event type: %s, payload: %s", eventType, string(payload)) + } + + case "assistantResponseEvent": + var contentDelta string + var toolUses []map[string]interface{} + + if assistantResp, ok := event["assistantResponseEvent"].(map[string]interface{}); ok { + if c, ok := assistantResp["content"].(string); ok { + contentDelta = c + } + // Extract stop_reason from assistantResponseEvent + if sr := kirocommon.GetString(assistantResp, "stop_reason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stop_reason in assistantResponseEvent: %s", upstreamStopReason) + } + if sr := kirocommon.GetString(assistantResp, "stopReason"); sr != "" { + upstreamStopReason = sr + log.Debugf("kiro: streamToChannel found stopReason in assistantResponseEvent: %s", upstreamStopReason) + } + // Extract tool uses from response + if tus, ok := assistantResp["toolUses"].([]interface{}); ok { + for _, tuRaw := range tus { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUses = append(toolUses, tu) + } + } + } + } + if contentDelta == "" { + if c, ok := event["content"].(string); ok { + contentDelta = c + } + } + // Direct tool uses + if tus, ok := event["toolUses"].([]interface{}); ok { + for _, tuRaw := range tus { + if tu, ok := tuRaw.(map[string]interface{}); ok { + toolUses = append(toolUses, tu) + } + } + } + + // Handle text content with thinking mode support + if contentDelta != "" { + // NOTE: Duplicate content filtering was removed because it incorrectly + // filtered out legitimate repeated content (like consecutive newlines "\n\n"). + // Streaming naturally can have identical chunks that are valid content. + + outputLen += len(contentDelta) + // Accumulate content for streaming token calculation + accumulatedContent.WriteString(contentDelta) + + // Real-time usage estimation: Check if we should send a usage update + // This helps clients track context usage during long thinking sessions + shouldSendUsageUpdate := false + if accumulatedContent.Len()-lastUsageUpdateLen >= usageUpdateCharThreshold { + shouldSendUsageUpdate = true + } else if time.Since(lastUsageUpdateTime) >= usageUpdateTimeInterval && accumulatedContent.Len() > lastUsageUpdateLen { + shouldSendUsageUpdate = true + } + + if shouldSendUsageUpdate { + // Calculate current output tokens using tiktoken + var currentOutputTokens int64 + if enc, encErr := getTokenizer(model); encErr == nil { + if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { + currentOutputTokens = int64(tokenCount) + } + } + // Fallback to character estimation if tiktoken fails + if currentOutputTokens == 0 { + currentOutputTokens = int64(accumulatedContent.Len() / 4) + if currentOutputTokens == 0 { + currentOutputTokens = 1 + } + } + + // Only send update if token count has changed significantly (at least 10 tokens) + if currentOutputTokens > lastReportedOutputTokens+10 { + // Send ping event with usage information + // This is a non-blocking update that clients can optionally process + pingEvent := kiroclaude.BuildClaudePingEventWithUsage(totalUsage.InputTokens, currentOutputTokens) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, pingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + lastReportedOutputTokens = currentOutputTokens + log.Debugf("kiro: sent real-time usage update - input: %d, output: %d (accumulated: %d chars)", + totalUsage.InputTokens, currentOutputTokens, accumulatedContent.Len()) + } + + lastUsageUpdateLen = accumulatedContent.Len() + lastUsageUpdateTime = time.Now() + } + + // Process content with thinking tag detection - based on amq2api implementation + // This handles and tags that may span across chunks + remaining := contentDelta + + // If we have pending start tag chars from previous chunk, prepend them + if pendingStartTagChars > 0 { + remaining = kirocommon.ThinkingStartTag[:pendingStartTagChars] + remaining + pendingStartTagChars = 0 + } + + // If we have pending end tag chars from previous chunk, prepend them + if pendingEndTagChars > 0 { + remaining = kirocommon.ThinkingEndTag[:pendingEndTagChars] + remaining + pendingEndTagChars = 0 + } + + for len(remaining) > 0 { + // CRITICAL FIX: Only parse tags when thinking mode was enabled in the request. + // When thinking is NOT enabled, tags in responses should be treated as + // regular text content, not as thinking blocks. This prevents normal text content + // from being incorrectly parsed as thinking when the model outputs tags + // without the user requesting thinking mode. + if !thinkingEnabled { + // Thinking not enabled - emit all content as regular text without parsing tags + if remaining != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := kiroclaude.BuildClaudeStreamEvent(remaining, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + break // Exit the for loop - all content processed as text + } + + // HEURISTIC FIX: Track code block and inline code state to avoid parsing tags + // inside code contexts. When the model outputs code examples containing these tags, + // they should be treated as text. + if !inThinkBlock { + // Check for inline code backticks first (higher priority than code fences) + // This handles cases like `` being discussed in text + backtickIdx := strings.Index(remaining, kirocommon.InlineCodeMarker) + thinkingIdx := strings.Index(remaining, kirocommon.ThinkingStartTag) + + // If backtick comes before thinking tag, handle inline code + if backtickIdx >= 0 && (thinkingIdx < 0 || backtickIdx < thinkingIdx) { + if inInlineCode { + // Closing backtick - emit content up to and including backtick, exit inline code + textToEmit := remaining[:backtickIdx+1] + if textToEmit != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textToEmit, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + remaining = remaining[backtickIdx+1:] + inInlineCode = false + continue + } else { + // Opening backtick - emit content before backtick, enter inline code + textToEmit := remaining[:backtickIdx+1] + if textToEmit != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textToEmit, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + remaining = remaining[backtickIdx+1:] + inInlineCode = true + continue + } + } + + // If inside inline code, emit all content as text (don't parse thinking tags) + if inInlineCode { + if remaining != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(remaining, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + break // Exit loop - remaining content is inside inline code + } + + // Check for code fence markers (``` or ~~~) to toggle code block state + fenceIdx := strings.Index(remaining, kirocommon.CodeFenceMarker) + altFenceIdx := strings.Index(remaining, kirocommon.AltCodeFenceMarker) + + // Find the earliest fence marker + earliestFenceIdx := -1 + earliestFenceType := "" + if fenceIdx >= 0 && (altFenceIdx < 0 || fenceIdx < altFenceIdx) { + earliestFenceIdx = fenceIdx + earliestFenceType = kirocommon.CodeFenceMarker + } else if altFenceIdx >= 0 { + earliestFenceIdx = altFenceIdx + earliestFenceType = kirocommon.AltCodeFenceMarker + } + + if earliestFenceIdx >= 0 { + // Check if this fence comes before any thinking tag + thinkingIdx := strings.Index(remaining, kirocommon.ThinkingStartTag) + if inCodeBlock { + // Inside code block - check if this fence closes it + if earliestFenceType == codeFenceType { + // This fence closes the code block + // Emit content up to and including the fence as text + textToEmit := remaining[:earliestFenceIdx+len(earliestFenceType)] + if textToEmit != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textToEmit, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + remaining = remaining[earliestFenceIdx+len(earliestFenceType):] + inCodeBlock = false + codeFenceType = "" + log.Debugf("kiro: exited code block") + continue + } + } else if thinkingIdx < 0 || earliestFenceIdx < thinkingIdx { + // Not in code block, and fence comes before thinking tag (or no thinking tag) + // Emit content up to and including the fence as text, then enter code block + textToEmit := remaining[:earliestFenceIdx+len(earliestFenceType)] + if textToEmit != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textToEmit, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + remaining = remaining[earliestFenceIdx+len(earliestFenceType):] + inCodeBlock = true + codeFenceType = earliestFenceType + log.Debugf("kiro: entered code block with fence: %s", earliestFenceType) + continue + } + } + + // If inside code block, emit all content as text (don't parse thinking tags) + if inCodeBlock { + if remaining != "" { + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + claudeEvent := kiroclaude.BuildClaudeStreamEvent(remaining, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + break // Exit loop - all remaining content is inside code block + } + } + + if inThinkBlock { + // Inside thinking block - look for end tag + // CRITICAL FIX: Skip tags that are not the real end tag + // This prevents false positives when thinking content discusses these tags + // Pass current code block/inline code state for accurate detection + endIdx := findRealThinkingEndTag(remaining, inCodeBlock, inInlineCode) + + if endIdx >= 0 { + // Found end tag - emit any content before end tag, then close block + thinkContent := remaining[:endIdx] + if thinkContent != "" { + // TRUE STREAMING: Emit thinking content immediately + // Start thinking block if not open + if !isThinkingBlockOpen { + contentBlockIndex++ + thinkingBlockIndex = contentBlockIndex + isThinkingBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Send thinking delta immediately + thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(thinkContent, thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Note: Partial tag handling is done via pendingEndTagChars + // When the next chunk arrives, the partial tag will be reconstructed + + // Close thinking block + if isThinkingBlockOpen { + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isThinkingBlockOpen = false + } + + inThinkBlock = false + thinkingBlockCompleted = true // Mark that we've completed a thinking block + remaining = remaining[endIdx+len(kirocommon.ThinkingEndTag):] + log.Debugf("kiro: exited thinking block, subsequent tags will be treated as text") + } else { + // No end tag found - TRUE STREAMING: emit content immediately + // Only save potential partial tag length for next iteration + pendingEnd := kiroclaude.PendingTagSuffix(remaining, kirocommon.ThinkingEndTag) + + // Calculate content to emit immediately (excluding potential partial tag) + var contentToEmit string + if pendingEnd > 0 { + contentToEmit = remaining[:len(remaining)-pendingEnd] + // Save partial tag length for next iteration (will be reconstructed from thinkingEndTag) + pendingEndTagChars = pendingEnd + } else { + contentToEmit = remaining + } + + // TRUE STREAMING: Emit thinking content immediately + if contentToEmit != "" { + // Start thinking block if not open + if !isThinkingBlockOpen { + contentBlockIndex++ + thinkingBlockIndex = contentBlockIndex + isThinkingBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(thinkingBlockIndex, "thinking", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Send thinking delta immediately - TRUE STREAMING! + thinkingEvent := kiroclaude.BuildClaudeThinkingDeltaEvent(contentToEmit, thinkingBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, thinkingEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + remaining = "" + } + } else { + // Outside thinking block - look for start tag + // CRITICAL FIX: Only parse tags at the very beginning of the response + // or if we haven't completed a thinking block yet. + // After a thinking block is completed, subsequent tags are likely + // discussion text (e.g., "Kiro returns `` tags") and should NOT be parsed. + startIdx := -1 + if !thinkingBlockCompleted && !hasSeenNonThinkingContent { + startIdx = strings.Index(remaining, kirocommon.ThinkingStartTag) + // If there's non-whitespace content before the tag, it's not a real thinking block + if startIdx > 0 { + textBefore := remaining[:startIdx] + if strings.TrimSpace(textBefore) != "" { + // There's real content before the tag - this is discussion text, not thinking + hasSeenNonThinkingContent = true + startIdx = -1 + log.Debugf("kiro: found tag after non-whitespace content, treating as text") + } + } + } + if startIdx >= 0 { + // Found start tag - emit text before it and switch to thinking mode + textBefore := remaining[:startIdx] + if textBefore != "" { + // Only whitespace before thinking tag is allowed + // Start text content block if needed + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textBefore, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Close text block before starting thinking block + if isTextBlockOpen { + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + inThinkBlock = true + remaining = remaining[startIdx+len(kirocommon.ThinkingStartTag):] + log.Debugf("kiro: entered thinking block") + } else { + // No start tag found - check for partial start tag at buffer end + // Only check for partial tags if we haven't completed a thinking block yet + pendingStart := 0 + if !thinkingBlockCompleted && !hasSeenNonThinkingContent { + pendingStart = kiroclaude.PendingTagSuffix(remaining, kirocommon.ThinkingStartTag) + } + if pendingStart > 0 { + // Emit text except potential partial tag + textToEmit := remaining[:len(remaining)-pendingStart] + if textToEmit != "" { + // Mark that we've seen non-thinking content + if strings.TrimSpace(textToEmit) != "" { + hasSeenNonThinkingContent = true + } + // Start text content block if needed + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := kiroclaude.BuildClaudeStreamEvent(textToEmit, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + pendingStartTagChars = pendingStart + remaining = "" + } else { + // No partial tag - emit all as text + if remaining != "" { + // Mark that we've seen non-thinking content + if strings.TrimSpace(remaining) != "" { + hasSeenNonThinkingContent = true + } + // Start text content block if needed + if !isTextBlockOpen { + contentBlockIndex++ + isTextBlockOpen = true + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "text", "", "") + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + claudeEvent := kiroclaude.BuildClaudeStreamEvent(remaining, contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, claudeEvent, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + remaining = "" + } + } + } + } + } + + // Handle tool uses in response (with deduplication) + for _, tu := range toolUses { + toolUseID := kirocommon.GetString(tu, "toolUseId") + + // Check for duplicate + if processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate tool use in stream: %s", toolUseID) + continue + } + processedIDs[toolUseID] = true + + hasToolUses = true + // Close text block if open before starting tool_use block + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + // Emit tool_use content block + contentBlockIndex++ + toolName := kirocommon.GetString(tu, "name") + + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", toolUseID, toolName) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send input_json_delta with the tool input + if input, ok := tu["input"].(map[string]interface{}); ok { + inputJSON, err := json.Marshal(input) + if err != nil { + log.Debugf("kiro: failed to marshal tool input: %v", err) + // Don't continue - still need to close the block + } else { + inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } + + // Close tool_use block (always close even if input marshal failed) + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + case "toolUseEvent": + // Handle dedicated tool use events with input buffering + completedToolUses, newState := kiroclaude.ProcessToolUseEvent(event, currentToolUse, processedIDs) + currentToolUse = newState + + // Emit completed tool uses + for _, tu := range completedToolUses { + hasToolUses = true + + // Close text block if open + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + isTextBlockOpen = false + } + + contentBlockIndex++ + + blockStart := kiroclaude.BuildClaudeContentBlockStartEvent(contentBlockIndex, "tool_use", tu.ToolUseID, tu.Name) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStart, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + if tu.Input != nil { + inputJSON, err := json.Marshal(tu.Input) + if err != nil { + log.Debugf("kiro: failed to marshal tool input in toolUseEvent: %v", err) + } else { + inputDelta := kiroclaude.BuildClaudeInputJsonDeltaEvent(string(inputJSON), contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, inputDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + } + + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + case "supplementaryWebLinksEvent": + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + + case "messageMetadataEvent": + // Handle message metadata events which may contain token counts + if metadata, ok := event["messageMetadataEvent"].(map[string]interface{}); ok { + if inputTokens, ok := metadata["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + log.Debugf("kiro: streamToChannel found inputTokens in messageMetadataEvent: %d", totalUsage.InputTokens) + } + if outputTokens, ok := metadata["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + log.Debugf("kiro: streamToChannel found outputTokens in messageMetadataEvent: %d", totalUsage.OutputTokens) + } + if totalTokens, ok := metadata["totalTokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + log.Debugf("kiro: streamToChannel found totalTokens in messageMetadataEvent: %d", totalUsage.TotalTokens) + } + } + + case "usageEvent", "usage": + // Handle dedicated usage events + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + log.Debugf("kiro: streamToChannel found inputTokens in usageEvent: %d", totalUsage.InputTokens) + } + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + log.Debugf("kiro: streamToChannel found outputTokens in usageEvent: %d", totalUsage.OutputTokens) + } + if totalTokens, ok := event["totalTokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + log.Debugf("kiro: streamToChannel found totalTokens in usageEvent: %d", totalUsage.TotalTokens) + } + // Also check nested usage object + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + } + log.Debugf("kiro: streamToChannel found usage object: input=%d, output=%d, total=%d", + totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + + case "metricsEvent": + // Handle metrics events which may contain usage data + if metrics, ok := event["metricsEvent"].(map[string]interface{}); ok { + if inputTokens, ok := metrics["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := metrics["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + log.Debugf("kiro: streamToChannel found metricsEvent: input=%d, output=%d", + totalUsage.InputTokens, totalUsage.OutputTokens) + } + } + + // Check nested usage event + if usageEvent, ok := event["supplementaryWebLinksEvent"].(map[string]interface{}); ok { + if inputTokens, ok := usageEvent["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + if outputTokens, ok := usageEvent["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + + // Check for direct token fields in any event (fallback) + if totalUsage.InputTokens == 0 { + if inputTokens, ok := event["inputTokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + log.Debugf("kiro: streamToChannel found direct inputTokens: %d", totalUsage.InputTokens) + } + } + if totalUsage.OutputTokens == 0 { + if outputTokens, ok := event["outputTokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + log.Debugf("kiro: streamToChannel found direct outputTokens: %d", totalUsage.OutputTokens) + } + } + + // Check for usage object in any event (OpenAI format) + if totalUsage.InputTokens == 0 || totalUsage.OutputTokens == 0 { + if usageObj, ok := event["usage"].(map[string]interface{}); ok { + if totalUsage.InputTokens == 0 { + if inputTokens, ok := usageObj["input_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } else if inputTokens, ok := usageObj["prompt_tokens"].(float64); ok { + totalUsage.InputTokens = int64(inputTokens) + } + } + if totalUsage.OutputTokens == 0 { + if outputTokens, ok := usageObj["output_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } else if outputTokens, ok := usageObj["completion_tokens"].(float64); ok { + totalUsage.OutputTokens = int64(outputTokens) + } + } + if totalUsage.TotalTokens == 0 { + if totalTokens, ok := usageObj["total_tokens"].(float64); ok { + totalUsage.TotalTokens = int64(totalTokens) + } + } + log.Debugf("kiro: streamToChannel found usage object (fallback): input=%d, output=%d, total=%d", + totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + } + } + + // Close content block if open + if isTextBlockOpen && contentBlockIndex >= 0 { + blockStop := kiroclaude.BuildClaudeContentBlockStopEvent(contentBlockIndex) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, blockStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + } + + // Streaming token calculation - calculate output tokens from accumulated content + // Only use local estimation if server didn't provide usage (server-side usage takes priority) + if totalUsage.OutputTokens == 0 && accumulatedContent.Len() > 0 { + // Try to use tiktoken for accurate counting + if enc, err := getTokenizer(model); err == nil { + if tokenCount, countErr := enc.Count(accumulatedContent.String()); countErr == nil { + totalUsage.OutputTokens = int64(tokenCount) + log.Debugf("kiro: streamToChannel calculated output tokens using tiktoken: %d", totalUsage.OutputTokens) + } else { + // Fallback on count error: estimate from character count + totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + log.Debugf("kiro: streamToChannel tiktoken count failed, estimated from chars: %d", totalUsage.OutputTokens) + } + } else { + // Fallback: estimate from character count (roughly 4 chars per token) + totalUsage.OutputTokens = int64(accumulatedContent.Len() / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + log.Debugf("kiro: streamToChannel estimated output tokens from chars: %d (content len: %d)", totalUsage.OutputTokens, accumulatedContent.Len()) + } + } else if totalUsage.OutputTokens == 0 && outputLen > 0 { + // Legacy fallback using outputLen + totalUsage.OutputTokens = int64(outputLen / 4) + if totalUsage.OutputTokens == 0 { + totalUsage.OutputTokens = 1 + } + } + + // Use contextUsagePercentage to calculate more accurate input tokens + // Kiro model has 200k max context, contextUsagePercentage represents the percentage used + // Formula: input_tokens = contextUsagePercentage * 200000 / 100 + // Note: The effective input context is ~170k (200k - 30k reserved for output) + if upstreamContextPercentage > 0 { + // Calculate input tokens from context percentage + // Using 200k as the base since that's what Kiro reports against + calculatedInputTokens := int64(upstreamContextPercentage * 200000 / 100) + + // Only use calculated value if it's significantly different from local estimate + // This provides more accurate token counts based on upstream data + if calculatedInputTokens > 0 { + localEstimate := totalUsage.InputTokens + totalUsage.InputTokens = calculatedInputTokens + log.Infof("kiro: using contextUsagePercentage (%.2f%%) to calculate input tokens: %d (local estimate was: %d)", + upstreamContextPercentage, calculatedInputTokens, localEstimate) + } + } + + totalUsage.TotalTokens = totalUsage.InputTokens + totalUsage.OutputTokens + + // Log upstream usage information if received + if hasUpstreamUsage { + log.Infof("kiro: upstream usage - credits: %.4f, context: %.2f%%, final tokens - input: %d, output: %d, total: %d", + upstreamCreditUsage, upstreamContextPercentage, + totalUsage.InputTokens, totalUsage.OutputTokens, totalUsage.TotalTokens) + } + + // Determine stop reason: prefer upstream, then detect tool_use, default to end_turn + stopReason := upstreamStopReason + if stopReason == "" { + if hasToolUses { + stopReason = "tool_use" + log.Debugf("kiro: streamToChannel using fallback stop_reason: tool_use") + } else { + stopReason = "end_turn" + log.Debugf("kiro: streamToChannel using fallback stop_reason: end_turn") + } + } + + // Log warning if response was truncated due to max_tokens + if stopReason == "max_tokens" { + log.Warnf("kiro: response truncated due to max_tokens limit (streamToChannel)") + } + + // Send message_delta event + msgDelta := kiroclaude.BuildClaudeMessageDeltaEvent(stopReason, totalUsage) + sseData := sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgDelta, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + + // Send message_stop event separately + msgStop := kiroclaude.BuildClaudeMessageStopOnlyEvent() + sseData = sdktranslator.TranslateStream(ctx, sdktranslator.FromString("kiro"), targetFormat, model, originalReq, claudeBody, msgStop, &translatorParam) + for _, chunk := range sseData { + if chunk != "" { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunk + "\n\n")} + } + } + // reporter.publish is called via defer +} + +// NOTE: Claude SSE event builders moved to internal/translator/kiro/claude/kiro_claude_stream.go +// The executor now uses kiroclaude.BuildClaude*Event() functions instead + +// CountTokens counts tokens locally using tiktoken since Kiro API doesn't expose a token counting endpoint. +// This provides approximate token counts for client requests. +func (e *KiroExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + // Use tiktoken for local token counting + enc, err := getTokenizer(req.Model) + if err != nil { + log.Warnf("kiro: CountTokens failed to get tokenizer: %v, falling back to estimate", err) + // Fallback: estimate from payload size (roughly 4 chars per token) + estimatedTokens := len(req.Payload) / 4 + if estimatedTokens == 0 && len(req.Payload) > 0 { + estimatedTokens = 1 + } + return cliproxyexecutor.Response{ + Payload: []byte(fmt.Sprintf(`{"count":%d}`, estimatedTokens)), + }, nil + } + + // Try to count tokens from the request payload + var totalTokens int64 + + // Try OpenAI chat format first + if tokens, countErr := countOpenAIChatTokens(enc, req.Payload); countErr == nil && tokens > 0 { + totalTokens = tokens + log.Debugf("kiro: CountTokens counted %d tokens using OpenAI chat format", totalTokens) + } else { + // Fallback: count raw payload tokens + if tokenCount, countErr := enc.Count(string(req.Payload)); countErr == nil { + totalTokens = int64(tokenCount) + log.Debugf("kiro: CountTokens counted %d tokens from raw payload", totalTokens) + } else { + // Final fallback: estimate from payload size + totalTokens = int64(len(req.Payload) / 4) + if totalTokens == 0 && len(req.Payload) > 0 { + totalTokens = 1 + } + log.Debugf("kiro: CountTokens estimated %d tokens from payload size", totalTokens) + } + } + + return cliproxyexecutor.Response{ + Payload: []byte(fmt.Sprintf(`{"count":%d}`, totalTokens)), + }, nil +} + +// Refresh refreshes the Kiro OAuth token. +// Supports both AWS Builder ID (SSO OIDC) and Google OAuth (social login). +// Uses mutex to prevent race conditions when multiple concurrent requests try to refresh. +func (e *KiroExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + // Serialize token refresh operations to prevent race conditions + e.refreshMu.Lock() + defer e.refreshMu.Unlock() + + var authID string + if auth != nil { + authID = auth.ID + } else { + authID = "" + } + log.Debugf("kiro executor: refresh called for auth %s", authID) + if auth == nil { + return nil, fmt.Errorf("kiro executor: auth is nil") + } + + // Double-check: After acquiring lock, verify token still needs refresh + // Another goroutine may have already refreshed while we were waiting + // NOTE: This check has a design limitation - it reads from the auth object passed in, + // not from persistent storage. If another goroutine returns a new Auth object (via Clone), + // this check won't see those updates. The mutex still prevents truly concurrent refreshes, + // but queued goroutines may still attempt redundant refreshes. This is acceptable as + // the refresh operation is idempotent and the extra API calls are infrequent. + if auth.Metadata != nil { + if lastRefresh, ok := auth.Metadata["last_refresh"].(string); ok { + if refreshTime, err := time.Parse(time.RFC3339, lastRefresh); err == nil { + // If token was refreshed within the last 30 seconds, skip refresh + if time.Since(refreshTime) < 30*time.Second { + log.Debugf("kiro executor: token was recently refreshed by another goroutine, skipping") + return auth, nil + } + } + } + // Also check if expires_at is now in the future with sufficient buffer + if expiresAt, ok := auth.Metadata["expires_at"].(string); ok { + if expTime, err := time.Parse(time.RFC3339, expiresAt); err == nil { + // If token expires more than 5 minutes from now, it's still valid + if time.Until(expTime) > 5*time.Minute { + log.Debugf("kiro executor: token is still valid (expires in %v), skipping refresh", time.Until(expTime)) + // CRITICAL FIX: Set NextRefreshAfter to prevent frequent refresh checks + // Without this, shouldRefresh() will return true again in 5 seconds + updated := auth.Clone() + // Set next refresh to 5 minutes before expiry, or at least 30 seconds from now + nextRefresh := expTime.Add(-5 * time.Minute) + minNextRefresh := time.Now().Add(30 * time.Second) + if nextRefresh.Before(minNextRefresh) { + nextRefresh = minNextRefresh + } + updated.NextRefreshAfter = nextRefresh + log.Debugf("kiro executor: setting NextRefreshAfter to %v (in %v)", nextRefresh.Format(time.RFC3339), time.Until(nextRefresh)) + return updated, nil + } + } + } + } + + var refreshToken string + var clientID, clientSecret string + var authMethod string + + if auth.Metadata != nil { + if rt, ok := auth.Metadata["refresh_token"].(string); ok { + refreshToken = rt + } + if cid, ok := auth.Metadata["client_id"].(string); ok { + clientID = cid + } + if cs, ok := auth.Metadata["client_secret"].(string); ok { + clientSecret = cs + } + if am, ok := auth.Metadata["auth_method"].(string); ok { + authMethod = am + } + } + + if refreshToken == "" { + return nil, fmt.Errorf("kiro executor: refresh token not found") + } + + var tokenData *kiroauth.KiroTokenData + var err error + + // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint + if clientID != "" && clientSecret != "" && authMethod == "builder-id" { + log.Debugf("kiro executor: using SSO OIDC refresh for AWS Builder ID") + ssoClient := kiroauth.NewSSOOIDCClient(e.cfg) + tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) + } else { + log.Debugf("kiro executor: using Kiro OAuth refresh endpoint") + oauth := kiroauth.NewKiroOAuth(e.cfg) + tokenData, err = oauth.RefreshToken(ctx, refreshToken) + } + + if err != nil { + return nil, fmt.Errorf("kiro executor: token refresh failed: %w", err) + } + + updated := auth.Clone() + now := time.Now() + updated.UpdatedAt = now + updated.LastRefreshedAt = now + + if updated.Metadata == nil { + updated.Metadata = make(map[string]any) + } + updated.Metadata["access_token"] = tokenData.AccessToken + updated.Metadata["refresh_token"] = tokenData.RefreshToken + updated.Metadata["expires_at"] = tokenData.ExpiresAt + updated.Metadata["last_refresh"] = now.Format(time.RFC3339) + if tokenData.ProfileArn != "" { + updated.Metadata["profile_arn"] = tokenData.ProfileArn + } + if tokenData.AuthMethod != "" { + updated.Metadata["auth_method"] = tokenData.AuthMethod + } + if tokenData.Provider != "" { + updated.Metadata["provider"] = tokenData.Provider + } + // Preserve client credentials for future refreshes (AWS Builder ID) + if tokenData.ClientID != "" { + updated.Metadata["client_id"] = tokenData.ClientID + } + if tokenData.ClientSecret != "" { + updated.Metadata["client_secret"] = tokenData.ClientSecret + } + + if updated.Attributes == nil { + updated.Attributes = make(map[string]string) + } + updated.Attributes["access_token"] = tokenData.AccessToken + if tokenData.ProfileArn != "" { + updated.Attributes["profile_arn"] = tokenData.ProfileArn + } + + // NextRefreshAfter is aligned with RefreshLead (5min) + if expiresAt, parseErr := time.Parse(time.RFC3339, tokenData.ExpiresAt); parseErr == nil { + updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute) + } + + log.Infof("kiro executor: token refreshed successfully, expires at %s", tokenData.ExpiresAt) + return updated, nil +} + +// isTokenExpired checks if a JWT access token has expired. +// Returns true if the token is expired or cannot be parsed. +func (e *KiroExecutor) isTokenExpired(accessToken string) bool { + if accessToken == "" { + return true + } + + // JWT tokens have 3 parts separated by dots + parts := strings.Split(accessToken, ".") + if len(parts) != 3 { + // Not a JWT token, assume not expired + return false + } + + // Decode the payload (second part) + // JWT uses base64url encoding without padding (RawURLEncoding) + payload := parts[1] + decoded, err := base64.RawURLEncoding.DecodeString(payload) + if err != nil { + // Try with padding added as fallback + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + decoded, err = base64.URLEncoding.DecodeString(payload) + if err != nil { + log.Debugf("kiro: failed to decode JWT payload: %v", err) + return false + } + } + + var claims struct { + Exp int64 `json:"exp"` + } + if err := json.Unmarshal(decoded, &claims); err != nil { + log.Debugf("kiro: failed to parse JWT claims: %v", err) + return false + } + + if claims.Exp == 0 { + // No expiration claim, assume not expired + return false + } + + expTime := time.Unix(claims.Exp, 0) + now := time.Now() + + // Consider token expired if it expires within 1 minute (buffer for clock skew) + isExpired := now.After(expTime) || expTime.Sub(now) < time.Minute + if isExpired { + log.Debugf("kiro: token expired at %s (now: %s)", expTime.Format(time.RFC3339), now.Format(time.RFC3339)) + } + + return isExpired +} + +// NOTE: Message merging functions moved to internal/translator/kiro/common/message_merge.go +// NOTE: Tool calling support functions moved to internal/translator/kiro/claude/kiro_claude_tools.go +// The executor now uses kiroclaude.* and kirocommon.* functions instead diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index ab0f626a..8998eb23 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "strings" + "sync" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -14,11 +15,19 @@ import ( "golang.org/x/net/proxy" ) +// httpClientCache caches HTTP clients by proxy URL to enable connection reuse +var ( + httpClientCache = make(map[string]*http.Client) + httpClientCacheMutex sync.RWMutex +) + // newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority: // 1. Use auth.ProxyURL if configured (highest priority) // 2. Use cfg.ProxyURL if auth proxy is not configured // 3. Use RoundTripper from context if neither are configured // +// This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse. +// // Parameters: // - ctx: The context containing optional RoundTripper // - cfg: The application configuration @@ -28,11 +37,6 @@ import ( // Returns: // - *http.Client: An HTTP client with configured proxy or transport func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client { - httpClient := &http.Client{} - if timeout > 0 { - httpClient.Timeout = timeout - } - // Priority 1: Use auth.ProxyURL if configured var proxyURL string if auth != nil { @@ -44,11 +48,39 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip proxyURL = strings.TrimSpace(cfg.ProxyURL) } + // Build cache key from proxy URL (empty string for no proxy) + cacheKey := proxyURL + + // Check cache first + httpClientCacheMutex.RLock() + if cachedClient, ok := httpClientCache[cacheKey]; ok { + httpClientCacheMutex.RUnlock() + // Return a wrapper with the requested timeout but shared transport + if timeout > 0 { + return &http.Client{ + Transport: cachedClient.Transport, + Timeout: timeout, + } + } + return cachedClient + } + httpClientCacheMutex.RUnlock() + + // Create new client + httpClient := &http.Client{} + if timeout > 0 { + httpClient.Timeout = timeout + } + // If we have a proxy URL configured, set up the transport if proxyURL != "" { transport := buildProxyTransport(proxyURL) if transport != nil { httpClient.Transport = transport + // Cache the client + httpClientCacheMutex.Lock() + httpClientCache[cacheKey] = httpClient + httpClientCacheMutex.Unlock() return httpClient } // If proxy setup failed, log and fall through to context RoundTripper @@ -60,6 +92,13 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip httpClient.Transport = rt } + // Cache the client for no-proxy case + if proxyURL == "" { + httpClientCacheMutex.Lock() + httpClientCache[cacheKey] = httpClient + httpClientCacheMutex.Unlock() + } + return httpClient } diff --git a/internal/runtime/executor/token_helpers.go b/internal/runtime/executor/token_helpers.go index f4236f9b..3dd2a2b5 100644 --- a/internal/runtime/executor/token_helpers.go +++ b/internal/runtime/executor/token_helpers.go @@ -2,43 +2,107 @@ package executor import ( "fmt" + "regexp" + "strconv" "strings" + "sync" "github.com/tidwall/gjson" "github.com/tiktoken-go/tokenizer" ) +// tokenizerCache stores tokenizer instances to avoid repeated creation +var tokenizerCache sync.Map + +// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models +// where tiktoken may not accurately estimate token counts (e.g., Claude models) +type TokenizerWrapper struct { + Codec tokenizer.Codec + AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates +} + +// Count returns the token count with adjustment factor applied +func (tw *TokenizerWrapper) Count(text string) (int, error) { + count, err := tw.Codec.Count(text) + if err != nil { + return 0, err + } + if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 { + return int(float64(count) * tw.AdjustmentFactor), nil + } + return count, nil +} + +// getTokenizer returns a cached tokenizer for the given model. +// This improves performance by avoiding repeated tokenizer creation. +func getTokenizer(model string) (*TokenizerWrapper, error) { + // Check cache first + if cached, ok := tokenizerCache.Load(model); ok { + return cached.(*TokenizerWrapper), nil + } + + // Cache miss, create new tokenizer + wrapper, err := tokenizerForModel(model) + if err != nil { + return nil, err + } + + // Store in cache (use LoadOrStore to handle race conditions) + actual, _ := tokenizerCache.LoadOrStore(model, wrapper) + return actual.(*TokenizerWrapper), nil +} + // tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id. -func tokenizerForModel(model string) (tokenizer.Codec, error) { +// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate. +func tokenizerForModel(model string) (*TokenizerWrapper, error) { sanitized := strings.ToLower(strings.TrimSpace(model)) + + // Claude models use cl100k_base with 1.1 adjustment factor + // because tiktoken may underestimate Claude's actual token count + if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") { + enc, err := tokenizer.Get(tokenizer.Cl100kBase) + if err != nil { + return nil, err + } + return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil + } + + var enc tokenizer.Codec + var err error + switch { case sanitized == "": - return tokenizer.Get(tokenizer.Cl100kBase) + enc, err = tokenizer.Get(tokenizer.Cl100kBase) case strings.HasPrefix(sanitized, "gpt-5"): - return tokenizer.ForModel(tokenizer.GPT5) + enc, err = tokenizer.ForModel(tokenizer.GPT5) case strings.HasPrefix(sanitized, "gpt-5.1"): - return tokenizer.ForModel(tokenizer.GPT5) + enc, err = tokenizer.ForModel(tokenizer.GPT5) case strings.HasPrefix(sanitized, "gpt-4.1"): - return tokenizer.ForModel(tokenizer.GPT41) + enc, err = tokenizer.ForModel(tokenizer.GPT41) case strings.HasPrefix(sanitized, "gpt-4o"): - return tokenizer.ForModel(tokenizer.GPT4o) + enc, err = tokenizer.ForModel(tokenizer.GPT4o) case strings.HasPrefix(sanitized, "gpt-4"): - return tokenizer.ForModel(tokenizer.GPT4) + enc, err = tokenizer.ForModel(tokenizer.GPT4) case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"): - return tokenizer.ForModel(tokenizer.GPT35Turbo) + enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo) case strings.HasPrefix(sanitized, "o1"): - return tokenizer.ForModel(tokenizer.O1) + enc, err = tokenizer.ForModel(tokenizer.O1) case strings.HasPrefix(sanitized, "o3"): - return tokenizer.ForModel(tokenizer.O3) + enc, err = tokenizer.ForModel(tokenizer.O3) case strings.HasPrefix(sanitized, "o4"): - return tokenizer.ForModel(tokenizer.O4Mini) + enc, err = tokenizer.ForModel(tokenizer.O4Mini) default: - return tokenizer.Get(tokenizer.O200kBase) + enc, err = tokenizer.Get(tokenizer.O200kBase) } + + if err != nil { + return nil, err + } + return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil } // countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads. -func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) { +func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) { if enc == nil { return 0, fmt.Errorf("encoder is nil") } @@ -62,11 +126,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) { return 0, nil } + // Count text tokens count, err := enc.Count(joined) if err != nil { return 0, err } - return int64(count), nil + + // Extract and add image tokens from placeholders + imageTokens := extractImageTokens(joined) + + return int64(count) + int64(imageTokens), nil +} + +// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads. +// This handles Claude's message format with system, messages, and tools. +// Image tokens are estimated based on image dimensions when available. +func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) { + if enc == nil { + return 0, fmt.Errorf("encoder is nil") + } + if len(payload) == 0 { + return 0, nil + } + + root := gjson.ParseBytes(payload) + segments := make([]string, 0, 32) + + // Collect system prompt (can be string or array of content blocks) + collectClaudeSystem(root.Get("system"), &segments) + + // Collect messages + collectClaudeMessages(root.Get("messages"), &segments) + + // Collect tools + collectClaudeTools(root.Get("tools"), &segments) + + joined := strings.TrimSpace(strings.Join(segments, "\n")) + if joined == "" { + return 0, nil + } + + // Count text tokens + count, err := enc.Count(joined) + if err != nil { + return 0, err + } + + // Extract and add image tokens from placeholders + imageTokens := extractImageTokens(joined) + + return int64(count) + int64(imageTokens), nil +} + +// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens +var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`) + +// extractImageTokens extracts image token estimates from placeholder text. +// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count. +func extractImageTokens(text string) int { + matches := imageTokenPattern.FindAllStringSubmatch(text, -1) + total := 0 + for _, match := range matches { + if len(match) > 1 { + if tokens, err := strconv.Atoi(match[1]); err == nil { + total += tokens + } + } + } + return total +} + +// estimateImageTokens calculates estimated tokens for an image based on dimensions. +// Based on Claude's image token calculation: tokens ≈ (width * height) / 750 +// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images). +func estimateImageTokens(width, height float64) int { + if width <= 0 || height <= 0 { + // No valid dimensions, use default estimate (medium-sized image) + return 1000 + } + + tokens := int(width * height / 750) + + // Apply bounds + if tokens < 85 { + tokens = 85 + } + if tokens > 1590 { + tokens = 1590 + } + + return tokens +} + +// collectClaudeSystem extracts text from Claude's system field. +// System can be a string or an array of content blocks. +func collectClaudeSystem(system gjson.Result, segments *[]string) { + if !system.Exists() { + return + } + if system.Type == gjson.String { + addIfNotEmpty(segments, system.String()) + return + } + if system.IsArray() { + system.ForEach(func(_, block gjson.Result) bool { + blockType := block.Get("type").String() + if blockType == "text" || blockType == "" { + addIfNotEmpty(segments, block.Get("text").String()) + } + // Also handle plain string blocks + if block.Type == gjson.String { + addIfNotEmpty(segments, block.String()) + } + return true + }) + } +} + +// collectClaudeMessages extracts text from Claude's messages array. +func collectClaudeMessages(messages gjson.Result, segments *[]string) { + if !messages.Exists() || !messages.IsArray() { + return + } + messages.ForEach(func(_, message gjson.Result) bool { + addIfNotEmpty(segments, message.Get("role").String()) + collectClaudeContent(message.Get("content"), segments) + return true + }) +} + +// collectClaudeContent extracts text from Claude's content field. +// Content can be a string or an array of content blocks. +// For images, estimates token count based on dimensions when available. +func collectClaudeContent(content gjson.Result, segments *[]string) { + if !content.Exists() { + return + } + if content.Type == gjson.String { + addIfNotEmpty(segments, content.String()) + return + } + if content.IsArray() { + content.ForEach(func(_, part gjson.Result) bool { + partType := part.Get("type").String() + switch partType { + case "text": + addIfNotEmpty(segments, part.Get("text").String()) + case "image": + // Estimate image tokens based on dimensions if available + source := part.Get("source") + if source.Exists() { + width := source.Get("width").Float() + height := source.Get("height").Float() + if width > 0 && height > 0 { + tokens := estimateImageTokens(width, height) + addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens)) + } else { + // No dimensions available, use default estimate + addIfNotEmpty(segments, "[IMAGE:1000 tokens]") + } + } else { + // No source info, use default estimate + addIfNotEmpty(segments, "[IMAGE:1000 tokens]") + } + case "tool_use": + addIfNotEmpty(segments, part.Get("id").String()) + addIfNotEmpty(segments, part.Get("name").String()) + if input := part.Get("input"); input.Exists() { + addIfNotEmpty(segments, input.Raw) + } + case "tool_result": + addIfNotEmpty(segments, part.Get("tool_use_id").String()) + collectClaudeContent(part.Get("content"), segments) + case "thinking": + addIfNotEmpty(segments, part.Get("thinking").String()) + default: + // For unknown types, try to extract any text content + if part.Type == gjson.String { + addIfNotEmpty(segments, part.String()) + } else if part.Type == gjson.JSON { + addIfNotEmpty(segments, part.Raw) + } + } + return true + }) + } +} + +// collectClaudeTools extracts text from Claude's tools array. +func collectClaudeTools(tools gjson.Result, segments *[]string) { + if !tools.Exists() || !tools.IsArray() { + return + } + tools.ForEach(func(_, tool gjson.Result) bool { + addIfNotEmpty(segments, tool.Get("name").String()) + addIfNotEmpty(segments, tool.Get("description").String()) + if inputSchema := tool.Get("input_schema"); inputSchema.Exists() { + addIfNotEmpty(segments, inputSchema.Raw) + } + return true + }) } // buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators. diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_response.go b/internal/translator/claude/openai/chat-completions/claude_openai_response.go index f8fd4018..5b0238bf 100644 --- a/internal/translator/claude/openai/chat-completions/claude_openai_response.go +++ b/internal/translator/claude/openai/chat-completions/claude_openai_response.go @@ -50,6 +50,10 @@ type ToolCallAccumulator struct { // Returns: // - []string: A slice of strings, each containing an OpenAI-compatible JSON response func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { + var localParam any + if param == nil { + param = &localParam + } if *param == nil { *param = &ConvertAnthropicResponseToOpenAIParams{ CreatedAt: 0, diff --git a/internal/translator/init.go b/internal/translator/init.go index 084ea7ac..0754db03 100644 --- a/internal/translator/init.go +++ b/internal/translator/init.go @@ -33,4 +33,7 @@ import ( _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses" + + _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude" + _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai" ) diff --git a/internal/translator/kiro/claude/init.go b/internal/translator/kiro/claude/init.go new file mode 100644 index 00000000..1685d195 --- /dev/null +++ b/internal/translator/kiro/claude/init.go @@ -0,0 +1,20 @@ +// Package claude provides translation between Kiro and Claude formats. +package claude + +import ( + . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator" +) + +func init() { + translator.Register( + Claude, + Kiro, + ConvertClaudeRequestToKiro, + interfaces.TranslateResponse{ + Stream: ConvertKiroStreamToClaude, + NonStream: ConvertKiroNonStreamToClaude, + }, + ) +} diff --git a/internal/translator/kiro/claude/kiro_claude.go b/internal/translator/kiro/claude/kiro_claude.go new file mode 100644 index 00000000..752a00d9 --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude.go @@ -0,0 +1,21 @@ +// Package claude provides translation between Kiro and Claude formats. +// Since Kiro executor generates Claude-compatible SSE format internally (with event: prefix), +// translations are pass-through for streaming, but responses need proper formatting. +package claude + +import ( + "context" +) + +// ConvertKiroStreamToClaude converts Kiro streaming response to Claude format. +// Kiro executor already generates complete SSE format with "event:" prefix, +// so this is a simple pass-through. +func ConvertKiroStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { + return []string{string(rawResponse)} +} + +// ConvertKiroNonStreamToClaude converts Kiro non-streaming response to Claude format. +// The response is already in Claude format, so this is a pass-through. +func ConvertKiroNonStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string { + return string(rawResponse) +} diff --git a/internal/translator/kiro/claude/kiro_claude_request.go b/internal/translator/kiro/claude/kiro_claude_request.go new file mode 100644 index 00000000..052d671c --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude_request.go @@ -0,0 +1,774 @@ +// Package claude provides request translation functionality for Claude API to Kiro format. +// It handles parsing and transforming Claude API requests into the Kiro/Amazon Q API format, +// extracting model information, system instructions, message contents, and tool declarations. +package claude + +import ( + "encoding/json" + "fmt" + "strings" + "time" + "unicode/utf8" + + "github.com/google/uuid" + kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" +) + + +// Kiro API request structs - field order determines JSON key order + +// KiroPayload is the top-level request structure for Kiro API +type KiroPayload struct { + ConversationState KiroConversationState `json:"conversationState"` + ProfileArn string `json:"profileArn,omitempty"` + InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"` +} + +// KiroInferenceConfig contains inference parameters for the Kiro API. +type KiroInferenceConfig struct { + MaxTokens int `json:"maxTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` +} + +// KiroConversationState holds the conversation context +type KiroConversationState struct { + ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field + ConversationID string `json:"conversationId"` + CurrentMessage KiroCurrentMessage `json:"currentMessage"` + History []KiroHistoryMessage `json:"history,omitempty"` +} + +// KiroCurrentMessage wraps the current user message +type KiroCurrentMessage struct { + UserInputMessage KiroUserInputMessage `json:"userInputMessage"` +} + +// KiroHistoryMessage represents a message in the conversation history +type KiroHistoryMessage struct { + UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"` + AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"` +} + +// KiroImage represents an image in Kiro API format +type KiroImage struct { + Format string `json:"format"` + Source KiroImageSource `json:"source"` +} + +// KiroImageSource contains the image data +type KiroImageSource struct { + Bytes string `json:"bytes"` // base64 encoded image data +} + +// KiroUserInputMessage represents a user message +type KiroUserInputMessage struct { + Content string `json:"content"` + ModelID string `json:"modelId"` + Origin string `json:"origin"` + Images []KiroImage `json:"images,omitempty"` + UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"` +} + +// KiroUserInputMessageContext contains tool-related context +type KiroUserInputMessageContext struct { + ToolResults []KiroToolResult `json:"toolResults,omitempty"` + Tools []KiroToolWrapper `json:"tools,omitempty"` +} + +// KiroToolResult represents a tool execution result +type KiroToolResult struct { + Content []KiroTextContent `json:"content"` + Status string `json:"status"` + ToolUseID string `json:"toolUseId"` +} + +// KiroTextContent represents text content +type KiroTextContent struct { + Text string `json:"text"` +} + +// KiroToolWrapper wraps a tool specification +type KiroToolWrapper struct { + ToolSpecification KiroToolSpecification `json:"toolSpecification"` +} + +// KiroToolSpecification defines a tool's schema +type KiroToolSpecification struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema KiroInputSchema `json:"inputSchema"` +} + +// KiroInputSchema wraps the JSON schema for tool input +type KiroInputSchema struct { + JSON interface{} `json:"json"` +} + +// KiroAssistantResponseMessage represents an assistant message +type KiroAssistantResponseMessage struct { + Content string `json:"content"` + ToolUses []KiroToolUse `json:"toolUses,omitempty"` +} + +// KiroToolUse represents a tool invocation by the assistant +type KiroToolUse struct { + ToolUseID string `json:"toolUseId"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` +} + +// ConvertClaudeRequestToKiro converts a Claude API request to Kiro format. +// This is the main entry point for request translation. +func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte { + // For Kiro, we pass through the Claude format since buildKiroPayload + // expects Claude format and does the conversion internally. + // The actual conversion happens in the executor when building the HTTP request. + return inputRawJSON +} + +// BuildKiroPayload constructs the Kiro API request payload from Claude format. +// Supports tool calling - tools are passed via userInputMessageContext. +// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE. +// isAgentic parameter enables chunked write optimization prompt for -agentic model variants. +// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode). +// Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint. +// Returns the payload and a boolean indicating whether thinking mode was injected. +func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) ([]byte, bool) { + // Extract max_tokens for potential use in inferenceConfig + // Handle -1 as "use maximum" (Kiro max output is ~32000 tokens) + const kiroMaxOutputTokens = 32000 + var maxTokens int64 + if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() { + maxTokens = mt.Int() + if maxTokens == -1 { + maxTokens = kiroMaxOutputTokens + log.Debugf("kiro: max_tokens=-1 converted to %d", kiroMaxOutputTokens) + } + } + + // Extract temperature if specified + var temperature float64 + var hasTemperature bool + if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() { + temperature = temp.Float() + hasTemperature = true + } + + // Extract top_p if specified + var topP float64 + var hasTopP bool + if tp := gjson.GetBytes(claudeBody, "top_p"); tp.Exists() { + topP = tp.Float() + hasTopP = true + log.Debugf("kiro: extracted top_p: %.2f", topP) + } + + // Normalize origin value for Kiro API compatibility + origin = normalizeOrigin(origin) + log.Debugf("kiro: normalized origin value: %s", origin) + + messages := gjson.GetBytes(claudeBody, "messages") + + // For chat-only mode, don't include tools + var tools gjson.Result + if !isChatOnly { + tools = gjson.GetBytes(claudeBody, "tools") + } + + // Extract system prompt + systemPrompt := extractSystemPrompt(claudeBody) + + // Check for thinking mode using the comprehensive IsThinkingEnabled function + // This supports Claude API format, OpenAI reasoning_effort, and AMP/Cursor format + thinkingEnabled := IsThinkingEnabled(claudeBody) + _, budgetTokens := checkThinkingMode(claudeBody) // Get budget tokens from Claude format if available + if budgetTokens <= 0 { + // Calculate budgetTokens based on max_tokens if available + // Use 50% of max_tokens for thinking, with min 8000 and max 24000 + if maxTokens > 0 { + budgetTokens = maxTokens / 2 + if budgetTokens < 8000 { + budgetTokens = 8000 + } + if budgetTokens > 24000 { + budgetTokens = 24000 + } + log.Debugf("kiro: budgetTokens calculated from max_tokens: %d (max_tokens=%d)", budgetTokens, maxTokens) + } else { + budgetTokens = 16000 // Default budget tokens + } + } + + // Inject timestamp context + timestamp := time.Now().Format("2006-01-02 15:04:05 MST") + timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp) + if systemPrompt != "" { + systemPrompt = timestampContext + "\n\n" + systemPrompt + } else { + systemPrompt = timestampContext + } + log.Debugf("kiro: injected timestamp context: %s", timestamp) + + // Inject agentic optimization prompt for -agentic model variants + if isAgentic { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += kirocommon.KiroAgenticSystemPrompt + } + + // Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints + // Claude tool_choice values: {"type": "auto/any/tool", "name": "..."} + toolChoiceHint := extractClaudeToolChoiceHint(claudeBody) + if toolChoiceHint != "" { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += toolChoiceHint + log.Debugf("kiro: injected tool_choice hint into system prompt") + } + + // Inject thinking hint when thinking mode is enabled + if thinkingEnabled { + if systemPrompt != "" { + systemPrompt += "\n" + } + dynamicThinkingHint := fmt.Sprintf("interleaved%d", budgetTokens) + systemPrompt += dynamicThinkingHint + log.Debugf("kiro: injected dynamic thinking hint into system prompt, max_thinking_length: %d", budgetTokens) + } + + // Convert Claude tools to Kiro format + kiroTools := convertClaudeToolsToKiro(tools) + + // Process messages and build history + history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin) + + // Build content with system prompt + if currentUserMsg != nil { + currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults) + + // Deduplicate currentToolResults + currentToolResults = deduplicateToolResults(currentToolResults) + + // Build userInputMessageContext with tools and tool results + if len(kiroTools) > 0 || len(currentToolResults) > 0 { + currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{ + Tools: kiroTools, + ToolResults: currentToolResults, + } + } + } + + // Build payload + var currentMessage KiroCurrentMessage + if currentUserMsg != nil { + currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg} + } else { + fallbackContent := "" + if systemPrompt != "" { + fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n" + } + currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{ + Content: fallbackContent, + ModelID: modelID, + Origin: origin, + }} + } + + // Build inferenceConfig if we have any inference parameters + var inferenceConfig *KiroInferenceConfig + if maxTokens > 0 || hasTemperature || hasTopP { + inferenceConfig = &KiroInferenceConfig{} + if maxTokens > 0 { + inferenceConfig.MaxTokens = int(maxTokens) + } + if hasTemperature { + inferenceConfig.Temperature = temperature + } + if hasTopP { + inferenceConfig.TopP = topP + } + } + + payload := KiroPayload{ + ConversationState: KiroConversationState{ + ChatTriggerType: "MANUAL", + ConversationID: uuid.New().String(), + CurrentMessage: currentMessage, + History: history, + }, + ProfileArn: profileArn, + InferenceConfig: inferenceConfig, + } + + result, err := json.Marshal(payload) + if err != nil { + log.Debugf("kiro: failed to marshal payload: %v", err) + return nil, false + } + + return result, thinkingEnabled +} + +// normalizeOrigin normalizes origin value for Kiro API compatibility +func normalizeOrigin(origin string) string { + switch origin { + case "KIRO_CLI": + return "CLI" + case "KIRO_AI_EDITOR": + return "AI_EDITOR" + case "AMAZON_Q": + return "CLI" + case "KIRO_IDE": + return "AI_EDITOR" + default: + return origin + } +} + +// extractSystemPrompt extracts system prompt from Claude request +func extractSystemPrompt(claudeBody []byte) string { + systemField := gjson.GetBytes(claudeBody, "system") + if systemField.IsArray() { + var sb strings.Builder + for _, block := range systemField.Array() { + if block.Get("type").String() == "text" { + sb.WriteString(block.Get("text").String()) + } else if block.Type == gjson.String { + sb.WriteString(block.String()) + } + } + return sb.String() + } + return systemField.String() +} + +// checkThinkingMode checks if thinking mode is enabled in the Claude request +func checkThinkingMode(claudeBody []byte) (bool, int64) { + thinkingEnabled := false + var budgetTokens int64 = 16000 + + thinkingField := gjson.GetBytes(claudeBody, "thinking") + if thinkingField.Exists() { + thinkingType := thinkingField.Get("type").String() + if thinkingType == "enabled" { + thinkingEnabled = true + if bt := thinkingField.Get("budget_tokens"); bt.Exists() { + budgetTokens = bt.Int() + if budgetTokens <= 0 { + thinkingEnabled = false + log.Debugf("kiro: thinking mode disabled via budget_tokens <= 0") + } + } + if thinkingEnabled { + log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens) + } + } + } + + return thinkingEnabled, budgetTokens +} + +// IsThinkingEnabled is a public wrapper to check if thinking mode is enabled. +// This is used by the executor to determine whether to parse tags in responses. +// When thinking is NOT enabled in the request, tags in responses should be +// treated as regular text content, not as thinking blocks. +// +// Supports multiple formats: +// - Claude API format: thinking.type = "enabled" +// - OpenAI format: reasoning_effort parameter +// - AMP/Cursor format: interleaved in system prompt +func IsThinkingEnabled(body []byte) bool { + // Check Claude API format first (thinking.type = "enabled") + enabled, _ := checkThinkingMode(body) + if enabled { + log.Debugf("kiro: IsThinkingEnabled returning true (Claude API format)") + return true + } + + // Check OpenAI format: reasoning_effort parameter + // Valid values: "low", "medium", "high", "auto" (not "none") + reasoningEffort := gjson.GetBytes(body, "reasoning_effort") + if reasoningEffort.Exists() { + effort := reasoningEffort.String() + if effort != "" && effort != "none" { + log.Debugf("kiro: thinking mode enabled via OpenAI reasoning_effort: %s", effort) + return true + } + } + + // Check AMP/Cursor format: interleaved in system prompt + // This is how AMP client passes thinking configuration + bodyStr := string(body) + if strings.Contains(bodyStr, "") && strings.Contains(bodyStr, "") { + // Extract thinking mode value + startTag := "" + endTag := "" + startIdx := strings.Index(bodyStr, startTag) + if startIdx >= 0 { + startIdx += len(startTag) + endIdx := strings.Index(bodyStr[startIdx:], endTag) + if endIdx >= 0 { + thinkingMode := bodyStr[startIdx : startIdx+endIdx] + if thinkingMode == "interleaved" || thinkingMode == "enabled" { + log.Debugf("kiro: thinking mode enabled via AMP/Cursor format: %s", thinkingMode) + return true + } + } + } + } + + // Check OpenAI format: max_completion_tokens with reasoning (o1-style) + // Some clients use this to indicate reasoning mode + if gjson.GetBytes(body, "max_completion_tokens").Exists() { + // If max_completion_tokens is set, check if model name suggests reasoning + model := gjson.GetBytes(body, "model").String() + if strings.Contains(strings.ToLower(model), "thinking") || + strings.Contains(strings.ToLower(model), "reason") { + log.Debugf("kiro: thinking mode enabled via model name hint: %s", model) + return true + } + } + + log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)") + return false +} + +// shortenToolNameIfNeeded shortens tool names that exceed 64 characters. +// MCP tools often have long names like "mcp__server-name__tool-name". +// This preserves the "mcp__" prefix and last segment when possible. +func shortenToolNameIfNeeded(name string) string { + const limit = 64 + if len(name) <= limit { + return name + } + // For MCP tools, try to preserve prefix and last segment + if strings.HasPrefix(name, "mcp__") { + idx := strings.LastIndex(name, "__") + if idx > 0 { + cand := "mcp__" + name[idx+2:] + if len(cand) > limit { + return cand[:limit] + } + return cand + } + } + return name[:limit] +} + +// convertClaudeToolsToKiro converts Claude tools to Kiro format +func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper { + var kiroTools []KiroToolWrapper + if !tools.IsArray() { + return kiroTools + } + + for _, tool := range tools.Array() { + name := tool.Get("name").String() + description := tool.Get("description").String() + inputSchema := tool.Get("input_schema").Value() + + // Shorten tool name if it exceeds 64 characters (common with MCP tools) + originalName := name + name = shortenToolNameIfNeeded(name) + if name != originalName { + log.Debugf("kiro: shortened tool name from '%s' to '%s'", originalName, name) + } + + // CRITICAL FIX: Kiro API requires non-empty description + if strings.TrimSpace(description) == "" { + description = fmt.Sprintf("Tool: %s", name) + log.Debugf("kiro: tool '%s' has empty description, using default: %s", name, description) + } + + // Truncate long descriptions + if len(description) > kirocommon.KiroMaxToolDescLen { + truncLen := kirocommon.KiroMaxToolDescLen - 30 + for truncLen > 0 && !utf8.RuneStart(description[truncLen]) { + truncLen-- + } + description = description[:truncLen] + "... (description truncated)" + } + + kiroTools = append(kiroTools, KiroToolWrapper{ + ToolSpecification: KiroToolSpecification{ + Name: name, + Description: description, + InputSchema: KiroInputSchema{JSON: inputSchema}, + }, + }) + } + + return kiroTools +} + +// processMessages processes Claude messages and builds Kiro history +func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) { + var history []KiroHistoryMessage + var currentUserMsg *KiroUserInputMessage + var currentToolResults []KiroToolResult + + // Merge adjacent messages with the same role + messagesArray := kirocommon.MergeAdjacentMessages(messages.Array()) + for i, msg := range messagesArray { + role := msg.Get("role").String() + isLastMessage := i == len(messagesArray)-1 + + if role == "user" { + userMsg, toolResults := BuildUserMessageStruct(msg, modelID, origin) + if isLastMessage { + currentUserMsg = &userMsg + currentToolResults = toolResults + } else { + // CRITICAL: Kiro API requires content to be non-empty for history messages too + if strings.TrimSpace(userMsg.Content) == "" { + if len(toolResults) > 0 { + userMsg.Content = "Tool results provided." + } else { + userMsg.Content = "Continue" + } + } + // For history messages, embed tool results in context + if len(toolResults) > 0 { + userMsg.UserInputMessageContext = &KiroUserInputMessageContext{ + ToolResults: toolResults, + } + } + history = append(history, KiroHistoryMessage{ + UserInputMessage: &userMsg, + }) + } + } else if role == "assistant" { + assistantMsg := BuildAssistantMessageStruct(msg) + if isLastMessage { + history = append(history, KiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + // Create a "Continue" user message as currentMessage + currentUserMsg = &KiroUserInputMessage{ + Content: "Continue", + ModelID: modelID, + Origin: origin, + } + } else { + history = append(history, KiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + } + } + } + + return history, currentUserMsg, currentToolResults +} + +// buildFinalContent builds the final content with system prompt +func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string { + var contentBuilder strings.Builder + + if systemPrompt != "" { + contentBuilder.WriteString("--- SYSTEM PROMPT ---\n") + contentBuilder.WriteString(systemPrompt) + contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n") + } + + contentBuilder.WriteString(content) + finalContent := contentBuilder.String() + + // CRITICAL: Kiro API requires content to be non-empty + if strings.TrimSpace(finalContent) == "" { + if len(toolResults) > 0 { + finalContent = "Tool results provided." + } else { + finalContent = "Continue" + } + log.Debugf("kiro: content was empty, using default: %s", finalContent) + } + + return finalContent +} + +// deduplicateToolResults removes duplicate tool results +func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult { + if len(toolResults) == 0 { + return toolResults + } + + seenIDs := make(map[string]bool) + unique := make([]KiroToolResult, 0, len(toolResults)) + for _, tr := range toolResults { + if !seenIDs[tr.ToolUseID] { + seenIDs[tr.ToolUseID] = true + unique = append(unique, tr) + } else { + log.Debugf("kiro: skipping duplicate toolResult in currentMessage: %s", tr.ToolUseID) + } + } + return unique +} + +// extractClaudeToolChoiceHint extracts tool_choice from Claude request and returns a system prompt hint. +// Claude tool_choice values: +// - {"type": "auto"}: Model decides (default, no hint needed) +// - {"type": "any"}: Must use at least one tool +// - {"type": "tool", "name": "..."}: Must use specific tool +func extractClaudeToolChoiceHint(claudeBody []byte) string { + toolChoice := gjson.GetBytes(claudeBody, "tool_choice") + if !toolChoice.Exists() { + return "" + } + + toolChoiceType := toolChoice.Get("type").String() + switch toolChoiceType { + case "any": + return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]" + case "tool": + toolName := toolChoice.Get("name").String() + if toolName != "" { + return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName) + } + case "auto": + // Default behavior, no hint needed + return "" + } + + return "" +} + +// BuildUserMessageStruct builds a user message and extracts tool results +func BuildUserMessageStruct(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) { + content := msg.Get("content") + var contentBuilder strings.Builder + var toolResults []KiroToolResult + var images []KiroImage + + // Track seen toolUseIds to deduplicate + seenToolUseIDs := make(map[string]bool) + + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + switch partType { + case "text": + contentBuilder.WriteString(part.Get("text").String()) + case "image": + mediaType := part.Get("source.media_type").String() + data := part.Get("source.data").String() + + format := "" + if idx := strings.LastIndex(mediaType, "/"); idx != -1 { + format = mediaType[idx+1:] + } + + if format != "" && data != "" { + images = append(images, KiroImage{ + Format: format, + Source: KiroImageSource{ + Bytes: data, + }, + }) + } + case "tool_result": + toolUseID := part.Get("tool_use_id").String() + + // Skip duplicate toolUseIds + if seenToolUseIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate tool_result with toolUseId: %s", toolUseID) + continue + } + seenToolUseIDs[toolUseID] = true + + isError := part.Get("is_error").Bool() + resultContent := part.Get("content") + + var textContents []KiroTextContent + if resultContent.IsArray() { + for _, item := range resultContent.Array() { + if item.Get("type").String() == "text" { + textContents = append(textContents, KiroTextContent{Text: item.Get("text").String()}) + } else if item.Type == gjson.String { + textContents = append(textContents, KiroTextContent{Text: item.String()}) + } + } + } else if resultContent.Type == gjson.String { + textContents = append(textContents, KiroTextContent{Text: resultContent.String()}) + } + + if len(textContents) == 0 { + textContents = append(textContents, KiroTextContent{Text: "Tool use was cancelled by the user"}) + } + + status := "success" + if isError { + status = "error" + } + + toolResults = append(toolResults, KiroToolResult{ + ToolUseID: toolUseID, + Content: textContents, + Status: status, + }) + } + } + } else { + contentBuilder.WriteString(content.String()) + } + + userMsg := KiroUserInputMessage{ + Content: contentBuilder.String(), + ModelID: modelID, + Origin: origin, + } + + if len(images) > 0 { + userMsg.Images = images + } + + return userMsg, toolResults +} + +// BuildAssistantMessageStruct builds an assistant message with tool uses +func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage { + content := msg.Get("content") + var contentBuilder strings.Builder + var toolUses []KiroToolUse + + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + switch partType { + case "text": + contentBuilder.WriteString(part.Get("text").String()) + case "tool_use": + toolUseID := part.Get("id").String() + toolName := part.Get("name").String() + toolInput := part.Get("input") + + var inputMap map[string]interface{} + if toolInput.IsObject() { + inputMap = make(map[string]interface{}) + toolInput.ForEach(func(key, value gjson.Result) bool { + inputMap[key.String()] = value.Value() + return true + }) + } + + toolUses = append(toolUses, KiroToolUse{ + ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + } + } + } else { + contentBuilder.WriteString(content.String()) + } + + return KiroAssistantResponseMessage{ + Content: contentBuilder.String(), + ToolUses: toolUses, + } +} \ No newline at end of file diff --git a/internal/translator/kiro/claude/kiro_claude_response.go b/internal/translator/kiro/claude/kiro_claude_response.go new file mode 100644 index 00000000..49ebf79e --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude_response.go @@ -0,0 +1,184 @@ +// Package claude provides response translation functionality for Kiro API to Claude format. +// This package handles the conversion of Kiro API responses into Claude-compatible format, +// including support for thinking blocks and tool use. +package claude + +import ( + "encoding/json" + "strings" + + "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + log "github.com/sirupsen/logrus" + + kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" +) + +// Local references to kirocommon constants for thinking block parsing +var ( + thinkingStartTag = kirocommon.ThinkingStartTag + thinkingEndTag = kirocommon.ThinkingEndTag +) + +// BuildClaudeResponse constructs a Claude-compatible response. +// Supports tool_use blocks when tools are present in the response. +// Supports thinking blocks - parses tags and converts to Claude thinking content blocks. +// stopReason is passed from upstream; fallback logic applied if empty. +func BuildClaudeResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte { + var contentBlocks []map[string]interface{} + + // Extract thinking blocks and text from content + if content != "" { + blocks := ExtractThinkingFromContent(content) + contentBlocks = append(contentBlocks, blocks...) + + // Log if thinking blocks were extracted + for _, block := range blocks { + if block["type"] == "thinking" { + thinkingContent := block["thinking"].(string) + log.Infof("kiro: buildClaudeResponse extracted thinking block (len: %d)", len(thinkingContent)) + } + } + } + + // Add tool_use blocks + for _, toolUse := range toolUses { + contentBlocks = append(contentBlocks, map[string]interface{}{ + "type": "tool_use", + "id": toolUse.ToolUseID, + "name": toolUse.Name, + "input": toolUse.Input, + }) + } + + // Ensure at least one content block (Claude API requires non-empty content) + if len(contentBlocks) == 0 { + contentBlocks = append(contentBlocks, map[string]interface{}{ + "type": "text", + "text": "", + }) + } + + // Use upstream stopReason; apply fallback logic if not provided + if stopReason == "" { + stopReason = "end_turn" + if len(toolUses) > 0 { + stopReason = "tool_use" + } + log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason) + } + + // Log warning if response was truncated due to max_tokens + if stopReason == "max_tokens" { + log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)") + } + + response := map[string]interface{}{ + "id": "msg_" + uuid.New().String()[:24], + "type": "message", + "role": "assistant", + "model": model, + "content": contentBlocks, + "stop_reason": stopReason, + "usage": map[string]interface{}{ + "input_tokens": usageInfo.InputTokens, + "output_tokens": usageInfo.OutputTokens, + }, + } + result, _ := json.Marshal(response) + return result +} + +// ExtractThinkingFromContent parses content to extract thinking blocks and text. +// Returns a list of content blocks in the order they appear in the content. +// Handles interleaved thinking and text blocks correctly. +func ExtractThinkingFromContent(content string) []map[string]interface{} { + var blocks []map[string]interface{} + + if content == "" { + return blocks + } + + // Check if content contains thinking tags at all + if !strings.Contains(content, thinkingStartTag) { + // No thinking tags, return as plain text + return []map[string]interface{}{ + { + "type": "text", + "text": content, + }, + } + } + + log.Debugf("kiro: extractThinkingFromContent - found thinking tags in content (len: %d)", len(content)) + + remaining := content + + for len(remaining) > 0 { + // Look for tag + startIdx := strings.Index(remaining, thinkingStartTag) + + if startIdx == -1 { + // No more thinking tags, add remaining as text + if strings.TrimSpace(remaining) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": remaining, + }) + } + break + } + + // Add text before thinking tag (if any meaningful content) + if startIdx > 0 { + textBefore := remaining[:startIdx] + if strings.TrimSpace(textBefore) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": textBefore, + }) + } + } + + // Move past the opening tag + remaining = remaining[startIdx+len(thinkingStartTag):] + + // Find closing tag + endIdx := strings.Index(remaining, thinkingEndTag) + + if endIdx == -1 { + // No closing tag found, treat rest as thinking content (incomplete response) + if strings.TrimSpace(remaining) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "thinking", + "thinking": remaining, + }) + log.Warnf("kiro: extractThinkingFromContent - missing closing tag") + } + break + } + + // Extract thinking content between tags + thinkContent := remaining[:endIdx] + if strings.TrimSpace(thinkContent) != "" { + blocks = append(blocks, map[string]interface{}{ + "type": "thinking", + "thinking": thinkContent, + }) + log.Debugf("kiro: extractThinkingFromContent - extracted thinking block (len: %d)", len(thinkContent)) + } + + // Move past the closing tag + remaining = remaining[endIdx+len(thinkingEndTag):] + } + + // If no blocks were created (all whitespace), return empty text block + if len(blocks) == 0 { + blocks = append(blocks, map[string]interface{}{ + "type": "text", + "text": "", + }) + } + + return blocks +} \ No newline at end of file diff --git a/internal/translator/kiro/claude/kiro_claude_stream.go b/internal/translator/kiro/claude/kiro_claude_stream.go new file mode 100644 index 00000000..6ea6e4cd --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude_stream.go @@ -0,0 +1,176 @@ +// Package claude provides streaming SSE event building for Claude format. +// This package handles the construction of Claude-compatible Server-Sent Events (SSE) +// for streaming responses from Kiro API. +package claude + +import ( + "encoding/json" + + "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" +) + +// BuildClaudeMessageStartEvent creates the message_start SSE event +func BuildClaudeMessageStartEvent(model string, inputTokens int64) []byte { + event := map[string]interface{}{ + "type": "message_start", + "message": map[string]interface{}{ + "id": "msg_" + uuid.New().String()[:24], + "type": "message", + "role": "assistant", + "content": []interface{}{}, + "model": model, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]interface{}{"input_tokens": inputTokens, "output_tokens": 0}, + }, + } + result, _ := json.Marshal(event) + return []byte("event: message_start\ndata: " + string(result)) +} + +// BuildClaudeContentBlockStartEvent creates a content_block_start SSE event +func BuildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte { + var contentBlock map[string]interface{} + switch blockType { + case "tool_use": + contentBlock = map[string]interface{}{ + "type": "tool_use", + "id": toolUseID, + "name": toolName, + "input": map[string]interface{}{}, + } + case "thinking": + contentBlock = map[string]interface{}{ + "type": "thinking", + "thinking": "", + } + default: + contentBlock = map[string]interface{}{ + "type": "text", + "text": "", + } + } + + event := map[string]interface{}{ + "type": "content_block_start", + "index": index, + "content_block": contentBlock, + } + result, _ := json.Marshal(event) + return []byte("event: content_block_start\ndata: " + string(result)) +} + +// BuildClaudeStreamEvent creates a text_delta content_block_delta SSE event +func BuildClaudeStreamEvent(contentDelta string, index int) []byte { + event := map[string]interface{}{ + "type": "content_block_delta", + "index": index, + "delta": map[string]interface{}{ + "type": "text_delta", + "text": contentDelta, + }, + } + result, _ := json.Marshal(event) + return []byte("event: content_block_delta\ndata: " + string(result)) +} + +// BuildClaudeInputJsonDeltaEvent creates an input_json_delta event for tool use streaming +func BuildClaudeInputJsonDeltaEvent(partialJSON string, index int) []byte { + event := map[string]interface{}{ + "type": "content_block_delta", + "index": index, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": partialJSON, + }, + } + result, _ := json.Marshal(event) + return []byte("event: content_block_delta\ndata: " + string(result)) +} + +// BuildClaudeContentBlockStopEvent creates a content_block_stop SSE event +func BuildClaudeContentBlockStopEvent(index int) []byte { + event := map[string]interface{}{ + "type": "content_block_stop", + "index": index, + } + result, _ := json.Marshal(event) + return []byte("event: content_block_stop\ndata: " + string(result)) +} + +// BuildClaudeMessageDeltaEvent creates the message_delta event with stop_reason and usage +func BuildClaudeMessageDeltaEvent(stopReason string, usageInfo usage.Detail) []byte { + deltaEvent := map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": map[string]interface{}{ + "input_tokens": usageInfo.InputTokens, + "output_tokens": usageInfo.OutputTokens, + }, + } + deltaResult, _ := json.Marshal(deltaEvent) + return []byte("event: message_delta\ndata: " + string(deltaResult)) +} + +// BuildClaudeMessageStopOnlyEvent creates only the message_stop event +func BuildClaudeMessageStopOnlyEvent() []byte { + stopEvent := map[string]interface{}{ + "type": "message_stop", + } + stopResult, _ := json.Marshal(stopEvent) + return []byte("event: message_stop\ndata: " + string(stopResult)) +} + +// BuildClaudePingEventWithUsage creates a ping event with embedded usage information. +// This is used for real-time usage estimation during streaming. +func BuildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte { + event := map[string]interface{}{ + "type": "ping", + "usage": map[string]interface{}{ + "input_tokens": inputTokens, + "output_tokens": outputTokens, + "total_tokens": inputTokens + outputTokens, + "estimated": true, + }, + } + result, _ := json.Marshal(event) + return []byte("event: ping\ndata: " + string(result)) +} + +// BuildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility. +// This is used when streaming thinking content wrapped in tags. +func BuildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte { + event := map[string]interface{}{ + "type": "content_block_delta", + "index": index, + "delta": map[string]interface{}{ + "type": "thinking_delta", + "thinking": thinkingDelta, + }, + } + result, _ := json.Marshal(event) + return []byte("event: content_block_delta\ndata: " + string(result)) +} + +// PendingTagSuffix detects if the buffer ends with a partial prefix of the given tag. +// Returns the length of the partial match (0 if no match). +// Based on amq2api implementation for handling cross-chunk tag boundaries. +func PendingTagSuffix(buffer, tag string) int { + if buffer == "" || tag == "" { + return 0 + } + maxLen := len(buffer) + if maxLen > len(tag)-1 { + maxLen = len(tag) - 1 + } + for length := maxLen; length > 0; length-- { + if len(buffer) >= length && buffer[len(buffer)-length:] == tag[:length] { + return length + } + } + return 0 +} \ No newline at end of file diff --git a/internal/translator/kiro/claude/kiro_claude_tools.go b/internal/translator/kiro/claude/kiro_claude_tools.go new file mode 100644 index 00000000..93ede875 --- /dev/null +++ b/internal/translator/kiro/claude/kiro_claude_tools.go @@ -0,0 +1,522 @@ +// Package claude provides tool calling support for Kiro to Claude translation. +// This package handles parsing embedded tool calls, JSON repair, and deduplication. +package claude + +import ( + "encoding/json" + "regexp" + "strings" + + "github.com/google/uuid" + kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" + log "github.com/sirupsen/logrus" +) + +// ToolUseState tracks the state of an in-progress tool use during streaming. +type ToolUseState struct { + ToolUseID string + Name string + InputBuffer strings.Builder + IsComplete bool +} + +// Pre-compiled regex patterns for performance +var ( + // embeddedToolCallPattern matches [Called tool_name with args: {...}] format + embeddedToolCallPattern = regexp.MustCompile(`\[Called\s+([A-Za-z0-9_.-]+)\s+with\s+args:\s*`) + // trailingCommaPattern matches trailing commas before closing braces/brackets + trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`) +) + +// ParseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text. +// Kiro sometimes embeds tool calls in text content instead of using toolUseEvent. +// Returns the cleaned text (with tool calls removed) and extracted tool uses. +func ParseEmbeddedToolCalls(text string, processedIDs map[string]bool) (string, []KiroToolUse) { + if !strings.Contains(text, "[Called") { + return text, nil + } + + var toolUses []KiroToolUse + cleanText := text + + // Find all [Called markers + matches := embeddedToolCallPattern.FindAllStringSubmatchIndex(text, -1) + if len(matches) == 0 { + return text, nil + } + + // Process matches in reverse order to maintain correct indices + for i := len(matches) - 1; i >= 0; i-- { + matchStart := matches[i][0] + toolNameStart := matches[i][2] + toolNameEnd := matches[i][3] + + if toolNameStart < 0 || toolNameEnd < 0 { + continue + } + + toolName := text[toolNameStart:toolNameEnd] + + // Find the JSON object start (after "with args:") + jsonStart := matches[i][1] + if jsonStart >= len(text) { + continue + } + + // Skip whitespace to find the opening brace + for jsonStart < len(text) && (text[jsonStart] == ' ' || text[jsonStart] == '\t') { + jsonStart++ + } + + if jsonStart >= len(text) || text[jsonStart] != '{' { + continue + } + + // Find matching closing bracket + jsonEnd := findMatchingBracket(text, jsonStart) + if jsonEnd < 0 { + continue + } + + // Extract JSON and find the closing bracket of [Called ...] + jsonStr := text[jsonStart : jsonEnd+1] + + // Find the closing ] after the JSON + closingBracket := jsonEnd + 1 + for closingBracket < len(text) && text[closingBracket] != ']' { + closingBracket++ + } + if closingBracket >= len(text) { + continue + } + + // End index of the full tool call (closing ']' inclusive) + matchEnd := closingBracket + 1 + + // Repair and parse JSON + repairedJSON := RepairJSON(jsonStr) + var inputMap map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &inputMap); err != nil { + log.Debugf("kiro: failed to parse embedded tool call JSON: %v, raw: %s", err, jsonStr) + continue + } + + // Generate unique tool ID + toolUseID := "toolu_" + uuid.New().String()[:12] + + // Check for duplicates using name+input as key + dedupeKey := toolName + ":" + repairedJSON + if processedIDs != nil { + if processedIDs[dedupeKey] { + log.Debugf("kiro: skipping duplicate embedded tool call: %s", toolName) + // Still remove from text even if duplicate + if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd { + cleanText = cleanText[:matchStart] + cleanText[matchEnd:] + } + continue + } + processedIDs[dedupeKey] = true + } + + toolUses = append(toolUses, KiroToolUse{ + ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + + log.Infof("kiro: extracted embedded tool call: %s (ID: %s)", toolName, toolUseID) + + // Remove from clean text (index-based removal to avoid deleting the wrong occurrence) + if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd { + cleanText = cleanText[:matchStart] + cleanText[matchEnd:] + } + } + + return cleanText, toolUses +} + +// findMatchingBracket finds the index of the closing brace/bracket that matches +// the opening one at startPos. Handles nested objects and strings correctly. +func findMatchingBracket(text string, startPos int) int { + if startPos >= len(text) { + return -1 + } + + openChar := text[startPos] + var closeChar byte + switch openChar { + case '{': + closeChar = '}' + case '[': + closeChar = ']' + default: + return -1 + } + + depth := 1 + inString := false + escapeNext := false + + for i := startPos + 1; i < len(text); i++ { + char := text[i] + + if escapeNext { + escapeNext = false + continue + } + + if char == '\\' && inString { + escapeNext = true + continue + } + + if char == '"' { + inString = !inString + continue + } + + if !inString { + if char == openChar { + depth++ + } else if char == closeChar { + depth-- + if depth == 0 { + return i + } + } + } + } + + return -1 +} + +// RepairJSON attempts to fix common JSON issues that may occur in tool call arguments. +// Conservative repair strategy: +// 1. First try to parse JSON directly - if valid, return as-is +// 2. Only attempt repair if parsing fails +// 3. After repair, validate the result - if still invalid, return original +func RepairJSON(jsonString string) string { + // Handle empty or invalid input + if jsonString == "" { + return "{}" + } + + str := strings.TrimSpace(jsonString) + if str == "" { + return "{}" + } + + // CONSERVATIVE STRATEGY: First try to parse directly + var testParse interface{} + if err := json.Unmarshal([]byte(str), &testParse); err == nil { + log.Debugf("kiro: repairJSON - JSON is already valid, returning unchanged") + return str + } + + log.Debugf("kiro: repairJSON - JSON parse failed, attempting repair") + originalStr := str + + // First, escape unescaped newlines/tabs within JSON string values + str = escapeNewlinesInStrings(str) + // Remove trailing commas before closing braces/brackets + str = trailingCommaPattern.ReplaceAllString(str, "$1") + + // Calculate bracket balance + braceCount := 0 + bracketCount := 0 + inString := false + escape := false + lastValidIndex := -1 + + for i := 0; i < len(str); i++ { + char := str[i] + + if escape { + escape = false + continue + } + + if char == '\\' { + escape = true + continue + } + + if char == '"' { + inString = !inString + continue + } + + if inString { + continue + } + + switch char { + case '{': + braceCount++ + case '}': + braceCount-- + case '[': + bracketCount++ + case ']': + bracketCount-- + } + + if braceCount >= 0 && bracketCount >= 0 { + lastValidIndex = i + } + } + + // If brackets are unbalanced, try to repair + if braceCount > 0 || bracketCount > 0 { + if lastValidIndex > 0 && lastValidIndex < len(str)-1 { + truncated := str[:lastValidIndex+1] + // Recount brackets after truncation + braceCount = 0 + bracketCount = 0 + inString = false + escape = false + for i := 0; i < len(truncated); i++ { + char := truncated[i] + if escape { + escape = false + continue + } + if char == '\\' { + escape = true + continue + } + if char == '"' { + inString = !inString + continue + } + if inString { + continue + } + switch char { + case '{': + braceCount++ + case '}': + braceCount-- + case '[': + bracketCount++ + case ']': + bracketCount-- + } + } + str = truncated + } + + // Add missing closing brackets + for braceCount > 0 { + str += "}" + braceCount-- + } + for bracketCount > 0 { + str += "]" + bracketCount-- + } + } + + // Validate repaired JSON + if err := json.Unmarshal([]byte(str), &testParse); err != nil { + log.Warnf("kiro: repairJSON - repair failed to produce valid JSON, returning original") + return originalStr + } + + log.Debugf("kiro: repairJSON - successfully repaired JSON") + return str +} + +// escapeNewlinesInStrings escapes literal newlines, tabs, and other control characters +// that appear inside JSON string values. +func escapeNewlinesInStrings(raw string) string { + var result strings.Builder + result.Grow(len(raw) + 100) + + inString := false + escaped := false + + for i := 0; i < len(raw); i++ { + c := raw[i] + + if escaped { + result.WriteByte(c) + escaped = false + continue + } + + if c == '\\' && inString { + result.WriteByte(c) + escaped = true + continue + } + + if c == '"' { + inString = !inString + result.WriteByte(c) + continue + } + + if inString { + switch c { + case '\n': + result.WriteString("\\n") + case '\r': + result.WriteString("\\r") + case '\t': + result.WriteString("\\t") + default: + result.WriteByte(c) + } + } else { + result.WriteByte(c) + } + } + + return result.String() +} + +// ProcessToolUseEvent handles a toolUseEvent from the Kiro stream. +// It accumulates input fragments and emits tool_use blocks when complete. +// Returns events to emit and updated state. +func ProcessToolUseEvent(event map[string]interface{}, currentToolUse *ToolUseState, processedIDs map[string]bool) ([]KiroToolUse, *ToolUseState) { + var toolUses []KiroToolUse + + // Extract from nested toolUseEvent or direct format + tu := event + if nested, ok := event["toolUseEvent"].(map[string]interface{}); ok { + tu = nested + } + + toolUseID := kirocommon.GetString(tu, "toolUseId") + toolName := kirocommon.GetString(tu, "name") + isStop := false + if stop, ok := tu["stop"].(bool); ok { + isStop = stop + } + + // Get input - can be string (fragment) or object (complete) + var inputFragment string + var inputMap map[string]interface{} + + if inputRaw, ok := tu["input"]; ok { + switch v := inputRaw.(type) { + case string: + inputFragment = v + case map[string]interface{}: + inputMap = v + } + } + + // New tool use starting + if toolUseID != "" && toolName != "" { + if currentToolUse != nil && currentToolUse.ToolUseID != toolUseID { + log.Warnf("kiro: interleaved tool use detected - new ID %s arrived while %s in progress, completing previous", + toolUseID, currentToolUse.ToolUseID) + if !processedIDs[currentToolUse.ToolUseID] { + incomplete := KiroToolUse{ + ToolUseID: currentToolUse.ToolUseID, + Name: currentToolUse.Name, + } + if currentToolUse.InputBuffer.Len() > 0 { + raw := currentToolUse.InputBuffer.String() + repaired := RepairJSON(raw) + + var input map[string]interface{} + if err := json.Unmarshal([]byte(repaired), &input); err != nil { + log.Warnf("kiro: failed to parse interleaved tool input: %v, raw: %s", err, raw) + input = make(map[string]interface{}) + } + incomplete.Input = input + } + toolUses = append(toolUses, incomplete) + processedIDs[currentToolUse.ToolUseID] = true + } + currentToolUse = nil + } + + if currentToolUse == nil { + if processedIDs != nil && processedIDs[toolUseID] { + log.Debugf("kiro: skipping duplicate toolUseEvent: %s", toolUseID) + return nil, nil + } + + currentToolUse = &ToolUseState{ + ToolUseID: toolUseID, + Name: toolName, + } + log.Infof("kiro: starting new tool use: %s (ID: %s)", toolName, toolUseID) + } + } + + // Accumulate input fragments + if currentToolUse != nil && inputFragment != "" { + currentToolUse.InputBuffer.WriteString(inputFragment) + log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.InputBuffer.Len()) + } + + // If complete input object provided directly + if currentToolUse != nil && inputMap != nil { + inputBytes, _ := json.Marshal(inputMap) + currentToolUse.InputBuffer.Reset() + currentToolUse.InputBuffer.Write(inputBytes) + } + + // Tool use complete + if isStop && currentToolUse != nil { + fullInput := currentToolUse.InputBuffer.String() + + // Repair and parse the accumulated JSON + repairedJSON := RepairJSON(fullInput) + var finalInput map[string]interface{} + if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil { + log.Warnf("kiro: failed to parse accumulated tool input: %v, raw: %s", err, fullInput) + finalInput = make(map[string]interface{}) + } + + toolUse := KiroToolUse{ + ToolUseID: currentToolUse.ToolUseID, + Name: currentToolUse.Name, + Input: finalInput, + } + toolUses = append(toolUses, toolUse) + + if processedIDs != nil { + processedIDs[currentToolUse.ToolUseID] = true + } + + log.Infof("kiro: completed tool use: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID) + return toolUses, nil + } + + return toolUses, currentToolUse +} + +// DeduplicateToolUses removes duplicate tool uses based on toolUseId and content. +func DeduplicateToolUses(toolUses []KiroToolUse) []KiroToolUse { + seenIDs := make(map[string]bool) + seenContent := make(map[string]bool) + var unique []KiroToolUse + + for _, tu := range toolUses { + if seenIDs[tu.ToolUseID] { + log.Debugf("kiro: removing ID-duplicate tool use: %s (name: %s)", tu.ToolUseID, tu.Name) + continue + } + + inputJSON, _ := json.Marshal(tu.Input) + contentKey := tu.Name + ":" + string(inputJSON) + + if seenContent[contentKey] { + log.Debugf("kiro: removing content-duplicate tool use: %s (id: %s)", tu.Name, tu.ToolUseID) + continue + } + + seenIDs[tu.ToolUseID] = true + seenContent[contentKey] = true + unique = append(unique, tu) + } + + return unique +} + diff --git a/internal/translator/kiro/common/constants.go b/internal/translator/kiro/common/constants.go new file mode 100644 index 00000000..96174b8c --- /dev/null +++ b/internal/translator/kiro/common/constants.go @@ -0,0 +1,75 @@ +// Package common provides shared constants and utilities for Kiro translator. +package common + +const ( + // KiroMaxToolDescLen is the maximum description length for Kiro API tools. + // Kiro API limit is 10240 bytes, leave room for "..." + KiroMaxToolDescLen = 10237 + + // ThinkingStartTag is the start tag for thinking blocks in responses. + ThinkingStartTag = "" + + // ThinkingEndTag is the end tag for thinking blocks in responses. + ThinkingEndTag = "" + + // CodeFenceMarker is the markdown code fence marker. + CodeFenceMarker = "```" + + // AltCodeFenceMarker is the alternative markdown code fence marker. + AltCodeFenceMarker = "~~~" + + // InlineCodeMarker is the markdown inline code marker (backtick). + InlineCodeMarker = "`" + + // KiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes. + // AWS Kiro API has a 2-3 minute timeout for large file write operations. + KiroAgenticSystemPrompt = ` +# CRITICAL: CHUNKED WRITE PROTOCOL (MANDATORY) + +You MUST follow these rules for ALL file operations. Violation causes server timeouts and task failure. + +## ABSOLUTE LIMITS +- **MAXIMUM 350 LINES** per single write/edit operation - NO EXCEPTIONS +- **RECOMMENDED 300 LINES** or less for optimal performance +- **NEVER** write entire files in one operation if >300 lines + +## MANDATORY CHUNKED WRITE STRATEGY + +### For NEW FILES (>300 lines total): +1. FIRST: Write initial chunk (first 250-300 lines) using write_to_file/fsWrite +2. THEN: Append remaining content in 250-300 line chunks using file append operations +3. REPEAT: Continue appending until complete + +### For EDITING EXISTING FILES: +1. Use surgical edits (apply_diff/targeted edits) - change ONLY what's needed +2. NEVER rewrite entire files - use incremental modifications +3. Split large refactors into multiple small, focused edits + +### For LARGE CODE GENERATION: +1. Generate in logical sections (imports, types, functions separately) +2. Write each section as a separate operation +3. Use append operations for subsequent sections + +## EXAMPLES OF CORRECT BEHAVIOR + +✅ CORRECT: Writing a 600-line file +- Operation 1: Write lines 1-300 (initial file creation) +- Operation 2: Append lines 301-600 + +✅ CORRECT: Editing multiple functions +- Operation 1: Edit function A +- Operation 2: Edit function B +- Operation 3: Edit function C + +❌ WRONG: Writing 500 lines in single operation → TIMEOUT +❌ WRONG: Rewriting entire file to change 5 lines → TIMEOUT +❌ WRONG: Generating massive code blocks without chunking → TIMEOUT + +## WHY THIS MATTERS +- Server has 2-3 minute timeout for operations +- Large writes exceed timeout and FAIL completely +- Chunked writes are FASTER and more RELIABLE +- Failed writes waste time and require retry + +REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.` +) \ No newline at end of file diff --git a/internal/translator/kiro/common/message_merge.go b/internal/translator/kiro/common/message_merge.go new file mode 100644 index 00000000..93f17f28 --- /dev/null +++ b/internal/translator/kiro/common/message_merge.go @@ -0,0 +1,125 @@ +// Package common provides shared utilities for Kiro translators. +package common + +import ( + "encoding/json" + + "github.com/tidwall/gjson" +) + +// MergeAdjacentMessages merges adjacent messages with the same role. +// This reduces API call complexity and improves compatibility. +// Based on AIClient-2-API implementation. +func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result { + if len(messages) <= 1 { + return messages + } + + var merged []gjson.Result + for _, msg := range messages { + if len(merged) == 0 { + merged = append(merged, msg) + continue + } + + lastMsg := merged[len(merged)-1] + currentRole := msg.Get("role").String() + lastRole := lastMsg.Get("role").String() + + if currentRole == lastRole { + // Merge content from current message into last message + mergedContent := mergeMessageContent(lastMsg, msg) + // Create a new merged message JSON + mergedMsg := createMergedMessage(lastRole, mergedContent) + merged[len(merged)-1] = gjson.Parse(mergedMsg) + } else { + merged = append(merged, msg) + } + } + + return merged +} + +// mergeMessageContent merges the content of two messages with the same role. +// Handles both string content and array content (with text, tool_use, tool_result blocks). +func mergeMessageContent(msg1, msg2 gjson.Result) string { + content1 := msg1.Get("content") + content2 := msg2.Get("content") + + // Extract content blocks from both messages + var blocks1, blocks2 []map[string]interface{} + + if content1.IsArray() { + for _, block := range content1.Array() { + blocks1 = append(blocks1, blockToMap(block)) + } + } else if content1.Type == gjson.String { + blocks1 = append(blocks1, map[string]interface{}{ + "type": "text", + "text": content1.String(), + }) + } + + if content2.IsArray() { + for _, block := range content2.Array() { + blocks2 = append(blocks2, blockToMap(block)) + } + } else if content2.Type == gjson.String { + blocks2 = append(blocks2, map[string]interface{}{ + "type": "text", + "text": content2.String(), + }) + } + + // Merge text blocks if both end/start with text + if len(blocks1) > 0 && len(blocks2) > 0 { + if blocks1[len(blocks1)-1]["type"] == "text" && blocks2[0]["type"] == "text" { + // Merge the last text block of msg1 with the first text block of msg2 + text1 := blocks1[len(blocks1)-1]["text"].(string) + text2 := blocks2[0]["text"].(string) + blocks1[len(blocks1)-1]["text"] = text1 + "\n" + text2 + blocks2 = blocks2[1:] // Remove the merged block from blocks2 + } + } + + // Combine all blocks + allBlocks := append(blocks1, blocks2...) + + // Convert to JSON + result, _ := json.Marshal(allBlocks) + return string(result) +} + +// blockToMap converts a gjson.Result block to a map[string]interface{} +func blockToMap(block gjson.Result) map[string]interface{} { + result := make(map[string]interface{}) + block.ForEach(func(key, value gjson.Result) bool { + if value.IsObject() { + result[key.String()] = blockToMap(value) + } else if value.IsArray() { + var arr []interface{} + for _, item := range value.Array() { + if item.IsObject() { + arr = append(arr, blockToMap(item)) + } else { + arr = append(arr, item.Value()) + } + } + result[key.String()] = arr + } else { + result[key.String()] = value.Value() + } + return true + }) + return result +} + +// createMergedMessage creates a JSON string for a merged message +func createMergedMessage(role string, content string) string { + msg := map[string]interface{}{ + "role": role, + "content": json.RawMessage(content), + } + result, _ := json.Marshal(msg) + return string(result) +} \ No newline at end of file diff --git a/internal/translator/kiro/common/utils.go b/internal/translator/kiro/common/utils.go new file mode 100644 index 00000000..f5f5788a --- /dev/null +++ b/internal/translator/kiro/common/utils.go @@ -0,0 +1,16 @@ +// Package common provides shared constants and utilities for Kiro translator. +package common + +// GetString safely extracts a string from a map. +// Returns empty string if the key doesn't exist or the value is not a string. +func GetString(m map[string]interface{}, key string) string { + if v, ok := m[key].(string); ok { + return v + } + return "" +} + +// GetStringValue is an alias for GetString for backward compatibility. +func GetStringValue(m map[string]interface{}, key string) string { + return GetString(m, key) +} \ No newline at end of file diff --git a/internal/translator/kiro/openai/init.go b/internal/translator/kiro/openai/init.go new file mode 100644 index 00000000..653eed45 --- /dev/null +++ b/internal/translator/kiro/openai/init.go @@ -0,0 +1,20 @@ +// Package openai provides translation between OpenAI Chat Completions and Kiro formats. +package openai + +import ( + . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator" +) + +func init() { + translator.Register( + OpenAI, // source format + Kiro, // target format + ConvertOpenAIRequestToKiro, + interfaces.TranslateResponse{ + Stream: ConvertKiroStreamToOpenAI, + NonStream: ConvertKiroNonStreamToOpenAI, + }, + ) +} \ No newline at end of file diff --git a/internal/translator/kiro/openai/kiro_openai.go b/internal/translator/kiro/openai/kiro_openai.go new file mode 100644 index 00000000..d5822998 --- /dev/null +++ b/internal/translator/kiro/openai/kiro_openai.go @@ -0,0 +1,369 @@ +// Package openai provides translation between OpenAI Chat Completions and Kiro formats. +// This package enables direct OpenAI → Kiro translation, bypassing the Claude intermediate layer. +// +// The Kiro executor generates Claude-compatible SSE format internally, so the streaming response +// translation converts from Claude SSE format to OpenAI SSE format. +package openai + +import ( + "bytes" + "context" + "encoding/json" + "strings" + + kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" +) + +// ConvertKiroStreamToOpenAI converts Kiro streaming response to OpenAI format. +// The Kiro executor emits Claude-compatible SSE events, so this function translates +// from Claude SSE format to OpenAI SSE format. +// +// Claude SSE format: +// - event: message_start\ndata: {...} +// - event: content_block_start\ndata: {...} +// - event: content_block_delta\ndata: {...} +// - event: content_block_stop\ndata: {...} +// - event: message_delta\ndata: {...} +// - event: message_stop\ndata: {...} +// +// OpenAI SSE format: +// - data: {"id":"...","object":"chat.completion.chunk",...} +// - data: [DONE] +func ConvertKiroStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string { + // Initialize state if needed + if *param == nil { + *param = NewOpenAIStreamState(model) + } + state := (*param).(*OpenAIStreamState) + + // Parse the Claude SSE event + responseStr := string(rawResponse) + + // Handle raw event format (event: xxx\ndata: {...}) + var eventType string + var eventData string + + if strings.HasPrefix(responseStr, "event:") { + // Parse event type and data + lines := strings.SplitN(responseStr, "\n", 2) + if len(lines) >= 1 { + eventType = strings.TrimSpace(strings.TrimPrefix(lines[0], "event:")) + } + if len(lines) >= 2 && strings.HasPrefix(lines[1], "data:") { + eventData = strings.TrimSpace(strings.TrimPrefix(lines[1], "data:")) + } + } else if strings.HasPrefix(responseStr, "data:") { + // Just data line + eventData = strings.TrimSpace(strings.TrimPrefix(responseStr, "data:")) + } else { + // Try to parse as raw JSON + eventData = strings.TrimSpace(responseStr) + } + + if eventData == "" { + return []string{} + } + + // Parse the event data as JSON + eventJSON := gjson.Parse(eventData) + if !eventJSON.Exists() { + return []string{} + } + + // Determine event type from JSON if not already set + if eventType == "" { + eventType = eventJSON.Get("type").String() + } + + var results []string + + switch eventType { + case "message_start": + // Send first chunk with role + firstChunk := BuildOpenAISSEFirstChunk(state) + results = append(results, firstChunk) + + case "content_block_start": + // Check block type + blockType := eventJSON.Get("content_block.type").String() + switch blockType { + case "text": + // Text block starting - nothing to emit yet + case "thinking": + // Thinking block starting - nothing to emit yet for OpenAI + case "tool_use": + // Tool use block starting + toolUseID := eventJSON.Get("content_block.id").String() + toolName := eventJSON.Get("content_block.name").String() + chunk := BuildOpenAISSEToolCallStart(state, toolUseID, toolName) + results = append(results, chunk) + state.ToolCallIndex++ + } + + case "content_block_delta": + deltaType := eventJSON.Get("delta.type").String() + switch deltaType { + case "text_delta": + textDelta := eventJSON.Get("delta.text").String() + if textDelta != "" { + chunk := BuildOpenAISSETextDelta(state, textDelta) + results = append(results, chunk) + } + case "thinking_delta": + // Convert thinking to reasoning_content for o1-style compatibility + thinkingDelta := eventJSON.Get("delta.thinking").String() + if thinkingDelta != "" { + chunk := BuildOpenAISSEReasoningDelta(state, thinkingDelta) + results = append(results, chunk) + } + case "input_json_delta": + // Tool call arguments delta + partialJSON := eventJSON.Get("delta.partial_json").String() + if partialJSON != "" { + // Get the tool index from content block index + blockIndex := int(eventJSON.Get("index").Int()) + chunk := BuildOpenAISSEToolCallArgumentsDelta(state, partialJSON, blockIndex-1) // Adjust for 0-based tool index + results = append(results, chunk) + } + } + + case "content_block_stop": + // Content block ended - nothing to emit for OpenAI + + case "message_delta": + // Message delta with stop_reason + stopReason := eventJSON.Get("delta.stop_reason").String() + finishReason := mapKiroStopReasonToOpenAI(stopReason) + if finishReason != "" { + chunk := BuildOpenAISSEFinish(state, finishReason) + results = append(results, chunk) + } + + // Extract usage if present + if eventJSON.Get("usage").Exists() { + inputTokens := eventJSON.Get("usage.input_tokens").Int() + outputTokens := eventJSON.Get("usage.output_tokens").Int() + usageInfo := usage.Detail{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + TotalTokens: inputTokens + outputTokens, + } + chunk := BuildOpenAISSEUsage(state, usageInfo) + results = append(results, chunk) + } + + case "message_stop": + // Final event - do NOT emit [DONE] here + // The handler layer (openai_handlers.go) will send [DONE] when the stream closes + // Emitting [DONE] here would cause duplicate [DONE] markers + + case "ping": + // Ping event with usage - optionally emit usage chunk + if eventJSON.Get("usage").Exists() { + inputTokens := eventJSON.Get("usage.input_tokens").Int() + outputTokens := eventJSON.Get("usage.output_tokens").Int() + usageInfo := usage.Detail{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + TotalTokens: inputTokens + outputTokens, + } + chunk := BuildOpenAISSEUsage(state, usageInfo) + results = append(results, chunk) + } + } + + return results +} + +// ConvertKiroNonStreamToOpenAI converts Kiro non-streaming response to OpenAI format. +// The Kiro executor returns Claude-compatible JSON responses, so this function translates +// from Claude format to OpenAI format. +func ConvertKiroNonStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string { + // Parse the Claude-format response + response := gjson.ParseBytes(rawResponse) + + // Extract content + var content string + var toolUses []KiroToolUse + var stopReason string + + // Get stop_reason + stopReason = response.Get("stop_reason").String() + + // Process content blocks + contentBlocks := response.Get("content") + if contentBlocks.IsArray() { + for _, block := range contentBlocks.Array() { + blockType := block.Get("type").String() + switch blockType { + case "text": + content += block.Get("text").String() + case "thinking": + // Skip thinking blocks for OpenAI format (or convert to reasoning_content if needed) + case "tool_use": + toolUseID := block.Get("id").String() + toolName := block.Get("name").String() + toolInput := block.Get("input") + + var inputMap map[string]interface{} + if toolInput.IsObject() { + inputMap = make(map[string]interface{}) + toolInput.ForEach(func(key, value gjson.Result) bool { + inputMap[key.String()] = value.Value() + return true + }) + } + + toolUses = append(toolUses, KiroToolUse{ + ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + } + } + } + + // Extract usage + usageInfo := usage.Detail{ + InputTokens: response.Get("usage.input_tokens").Int(), + OutputTokens: response.Get("usage.output_tokens").Int(), + } + usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens + + // Build OpenAI response + openaiResponse := BuildOpenAIResponse(content, toolUses, model, usageInfo, stopReason) + return string(openaiResponse) +} + +// ParseClaudeEvent parses a Claude SSE event and returns the event type and data +func ParseClaudeEvent(rawEvent []byte) (eventType string, eventData []byte) { + lines := bytes.Split(rawEvent, []byte("\n")) + for _, line := range lines { + line = bytes.TrimSpace(line) + if bytes.HasPrefix(line, []byte("event:")) { + eventType = string(bytes.TrimSpace(bytes.TrimPrefix(line, []byte("event:")))) + } else if bytes.HasPrefix(line, []byte("data:")) { + eventData = bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:"))) + } + } + return eventType, eventData +} + +// ExtractThinkingFromContent parses content to extract thinking blocks. +// Returns cleaned content (without thinking tags) and whether thinking was found. +func ExtractThinkingFromContent(content string) (string, string, bool) { + if !strings.Contains(content, kirocommon.ThinkingStartTag) { + return content, "", false + } + + var cleanedContent strings.Builder + var thinkingContent strings.Builder + hasThinking := false + remaining := content + + for len(remaining) > 0 { + startIdx := strings.Index(remaining, kirocommon.ThinkingStartTag) + if startIdx == -1 { + cleanedContent.WriteString(remaining) + break + } + + // Add content before thinking tag + cleanedContent.WriteString(remaining[:startIdx]) + + // Move past opening tag + remaining = remaining[startIdx+len(kirocommon.ThinkingStartTag):] + + // Find closing tag + endIdx := strings.Index(remaining, kirocommon.ThinkingEndTag) + if endIdx == -1 { + // No closing tag - treat rest as thinking + thinkingContent.WriteString(remaining) + hasThinking = true + break + } + + // Extract thinking content + thinkingContent.WriteString(remaining[:endIdx]) + hasThinking = true + remaining = remaining[endIdx+len(kirocommon.ThinkingEndTag):] + } + + return strings.TrimSpace(cleanedContent.String()), strings.TrimSpace(thinkingContent.String()), hasThinking +} + +// ConvertOpenAIToolsToKiroFormat is a helper that converts OpenAI tools format to Kiro format +func ConvertOpenAIToolsToKiroFormat(tools []map[string]interface{}) []KiroToolWrapper { + var kiroTools []KiroToolWrapper + + for _, tool := range tools { + toolType, _ := tool["type"].(string) + if toolType != "function" { + continue + } + + fn, ok := tool["function"].(map[string]interface{}) + if !ok { + continue + } + + name := kirocommon.GetString(fn, "name") + description := kirocommon.GetString(fn, "description") + parameters := fn["parameters"] + + if name == "" { + continue + } + + if description == "" { + description = "Tool: " + name + } + + kiroTools = append(kiroTools, KiroToolWrapper{ + ToolSpecification: KiroToolSpecification{ + Name: name, + Description: description, + InputSchema: KiroInputSchema{JSON: parameters}, + }, + }) + } + + return kiroTools +} + +// OpenAIStreamParams holds parameters for OpenAI streaming conversion +type OpenAIStreamParams struct { + State *OpenAIStreamState + ThinkingState *ThinkingTagState + ToolCallsEmitted map[string]bool +} + +// NewOpenAIStreamParams creates new streaming parameters +func NewOpenAIStreamParams(model string) *OpenAIStreamParams { + return &OpenAIStreamParams{ + State: NewOpenAIStreamState(model), + ThinkingState: NewThinkingTagState(), + ToolCallsEmitted: make(map[string]bool), + } +} + +// ConvertClaudeToolUseToOpenAI converts a Claude tool_use block to OpenAI tool_calls format +func ConvertClaudeToolUseToOpenAI(toolUseID, toolName string, input map[string]interface{}) map[string]interface{} { + inputJSON, _ := json.Marshal(input) + return map[string]interface{}{ + "id": toolUseID, + "type": "function", + "function": map[string]interface{}{ + "name": toolName, + "arguments": string(inputJSON), + }, + } +} + +// LogStreamEvent logs a streaming event for debugging +func LogStreamEvent(eventType, data string) { + log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data)) +} \ No newline at end of file diff --git a/internal/translator/kiro/openai/kiro_openai_request.go b/internal/translator/kiro/openai/kiro_openai_request.go new file mode 100644 index 00000000..cb97a340 --- /dev/null +++ b/internal/translator/kiro/openai/kiro_openai_request.go @@ -0,0 +1,848 @@ +// Package openai provides request translation from OpenAI Chat Completions to Kiro format. +// It handles parsing and transforming OpenAI API requests into the Kiro/Amazon Q API format, +// extracting model information, system instructions, message contents, and tool declarations. +package openai + +import ( + "encoding/json" + "fmt" + "strings" + "time" + "unicode/utf8" + + "github.com/google/uuid" + kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" +) + +// Kiro API request structs - reuse from kiroclaude package structure + +// KiroPayload is the top-level request structure for Kiro API +type KiroPayload struct { + ConversationState KiroConversationState `json:"conversationState"` + ProfileArn string `json:"profileArn,omitempty"` + InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"` +} + +// KiroInferenceConfig contains inference parameters for the Kiro API. +type KiroInferenceConfig struct { + MaxTokens int `json:"maxTokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` +} + +// KiroConversationState holds the conversation context +type KiroConversationState struct { + ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" + ConversationID string `json:"conversationId"` + CurrentMessage KiroCurrentMessage `json:"currentMessage"` + History []KiroHistoryMessage `json:"history,omitempty"` +} + +// KiroCurrentMessage wraps the current user message +type KiroCurrentMessage struct { + UserInputMessage KiroUserInputMessage `json:"userInputMessage"` +} + +// KiroHistoryMessage represents a message in the conversation history +type KiroHistoryMessage struct { + UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"` + AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"` +} + +// KiroImage represents an image in Kiro API format +type KiroImage struct { + Format string `json:"format"` + Source KiroImageSource `json:"source"` +} + +// KiroImageSource contains the image data +type KiroImageSource struct { + Bytes string `json:"bytes"` // base64 encoded image data +} + +// KiroUserInputMessage represents a user message +type KiroUserInputMessage struct { + Content string `json:"content"` + ModelID string `json:"modelId"` + Origin string `json:"origin"` + Images []KiroImage `json:"images,omitempty"` + UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"` +} + +// KiroUserInputMessageContext contains tool-related context +type KiroUserInputMessageContext struct { + ToolResults []KiroToolResult `json:"toolResults,omitempty"` + Tools []KiroToolWrapper `json:"tools,omitempty"` +} + +// KiroToolResult represents a tool execution result +type KiroToolResult struct { + Content []KiroTextContent `json:"content"` + Status string `json:"status"` + ToolUseID string `json:"toolUseId"` +} + +// KiroTextContent represents text content +type KiroTextContent struct { + Text string `json:"text"` +} + +// KiroToolWrapper wraps a tool specification +type KiroToolWrapper struct { + ToolSpecification KiroToolSpecification `json:"toolSpecification"` +} + +// KiroToolSpecification defines a tool's schema +type KiroToolSpecification struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema KiroInputSchema `json:"inputSchema"` +} + +// KiroInputSchema wraps the JSON schema for tool input +type KiroInputSchema struct { + JSON interface{} `json:"json"` +} + +// KiroAssistantResponseMessage represents an assistant message +type KiroAssistantResponseMessage struct { + Content string `json:"content"` + ToolUses []KiroToolUse `json:"toolUses,omitempty"` +} + +// KiroToolUse represents a tool invocation by the assistant +type KiroToolUse struct { + ToolUseID string `json:"toolUseId"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` +} + +// ConvertOpenAIRequestToKiro converts an OpenAI Chat Completions request to Kiro format. +// This is the main entry point for request translation. +// Note: The actual payload building happens in the executor, this just passes through +// the OpenAI format which will be converted by BuildKiroPayloadFromOpenAI. +func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte { + // Pass through the OpenAI format - actual conversion happens in BuildKiroPayloadFromOpenAI + return inputRawJSON +} + +// BuildKiroPayloadFromOpenAI constructs the Kiro API request payload from OpenAI format. +// Supports tool calling - tools are passed via userInputMessageContext. +// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE. +// isAgentic parameter enables chunked write optimization prompt for -agentic model variants. +// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode). +// Returns the payload and a boolean indicating whether thinking mode was injected. +func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) ([]byte, bool) { + // Extract max_tokens for potential use in inferenceConfig + // Handle -1 as "use maximum" (Kiro max output is ~32000 tokens) + const kiroMaxOutputTokens = 32000 + var maxTokens int64 + if mt := gjson.GetBytes(openaiBody, "max_tokens"); mt.Exists() { + maxTokens = mt.Int() + if maxTokens == -1 { + maxTokens = kiroMaxOutputTokens + log.Debugf("kiro-openai: max_tokens=-1 converted to %d", kiroMaxOutputTokens) + } + } + + // Extract temperature if specified + var temperature float64 + var hasTemperature bool + if temp := gjson.GetBytes(openaiBody, "temperature"); temp.Exists() { + temperature = temp.Float() + hasTemperature = true + } + + // Extract top_p if specified + var topP float64 + var hasTopP bool + if tp := gjson.GetBytes(openaiBody, "top_p"); tp.Exists() { + topP = tp.Float() + hasTopP = true + log.Debugf("kiro-openai: extracted top_p: %.2f", topP) + } + + // Normalize origin value for Kiro API compatibility + origin = normalizeOrigin(origin) + log.Debugf("kiro-openai: normalized origin value: %s", origin) + + messages := gjson.GetBytes(openaiBody, "messages") + + // For chat-only mode, don't include tools + var tools gjson.Result + if !isChatOnly { + tools = gjson.GetBytes(openaiBody, "tools") + } + + // Extract system prompt from messages + systemPrompt := extractSystemPromptFromOpenAI(messages) + + // Inject timestamp context + timestamp := time.Now().Format("2006-01-02 15:04:05 MST") + timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp) + if systemPrompt != "" { + systemPrompt = timestampContext + "\n\n" + systemPrompt + } else { + systemPrompt = timestampContext + } + log.Debugf("kiro-openai: injected timestamp context: %s", timestamp) + + // Inject agentic optimization prompt for -agentic model variants + if isAgentic { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += kirocommon.KiroAgenticSystemPrompt + } + + // Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints + // OpenAI tool_choice values: "none", "auto", "required", or {"type":"function","function":{"name":"..."}} + toolChoiceHint := extractToolChoiceHint(openaiBody) + if toolChoiceHint != "" { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += toolChoiceHint + log.Debugf("kiro-openai: injected tool_choice hint into system prompt") + } + + // Handle response_format parameter - Kiro doesn't support it natively, so we inject system prompt hints + // OpenAI response_format: {"type": "json_object"} or {"type": "json_schema", "json_schema": {...}} + responseFormatHint := extractResponseFormatHint(openaiBody) + if responseFormatHint != "" { + if systemPrompt != "" { + systemPrompt += "\n" + } + systemPrompt += responseFormatHint + log.Debugf("kiro-openai: injected response_format hint into system prompt") + } + + // Check for thinking mode and inject thinking hint + // Supports OpenAI reasoning_effort parameter and model name hints + thinkingEnabled, budgetTokens := checkThinkingModeFromOpenAI(openaiBody) + if thinkingEnabled { + // Adjust budgetTokens based on max_tokens if not explicitly set by reasoning_effort + // Use 50% of max_tokens for thinking, with min 8000 and max 24000 + if maxTokens > 0 && budgetTokens == 16000 { // 16000 is the default, meaning not explicitly set + calculatedBudget := maxTokens / 2 + if calculatedBudget < 8000 { + calculatedBudget = 8000 + } + if calculatedBudget > 24000 { + calculatedBudget = 24000 + } + budgetTokens = calculatedBudget + log.Debugf("kiro-openai: budgetTokens calculated from max_tokens: %d (max_tokens=%d)", budgetTokens, maxTokens) + } + + if systemPrompt != "" { + systemPrompt += "\n" + } + dynamicThinkingHint := fmt.Sprintf("interleaved%d", budgetTokens) + systemPrompt += dynamicThinkingHint + log.Debugf("kiro-openai: injected dynamic thinking hint into system prompt, max_thinking_length: %d", budgetTokens) + } + + // Convert OpenAI tools to Kiro format + kiroTools := convertOpenAIToolsToKiro(tools) + + // Process messages and build history + history, currentUserMsg, currentToolResults := processOpenAIMessages(messages, modelID, origin) + + // Build content with system prompt + if currentUserMsg != nil { + currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults) + + // Deduplicate currentToolResults + currentToolResults = deduplicateToolResults(currentToolResults) + + // Build userInputMessageContext with tools and tool results + if len(kiroTools) > 0 || len(currentToolResults) > 0 { + currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{ + Tools: kiroTools, + ToolResults: currentToolResults, + } + } + } + + // Build payload + var currentMessage KiroCurrentMessage + if currentUserMsg != nil { + currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg} + } else { + fallbackContent := "" + if systemPrompt != "" { + fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n" + } + currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{ + Content: fallbackContent, + ModelID: modelID, + Origin: origin, + }} + } + + // Build inferenceConfig if we have any inference parameters + var inferenceConfig *KiroInferenceConfig + if maxTokens > 0 || hasTemperature || hasTopP { + inferenceConfig = &KiroInferenceConfig{} + if maxTokens > 0 { + inferenceConfig.MaxTokens = int(maxTokens) + } + if hasTemperature { + inferenceConfig.Temperature = temperature + } + if hasTopP { + inferenceConfig.TopP = topP + } + } + + payload := KiroPayload{ + ConversationState: KiroConversationState{ + ChatTriggerType: "MANUAL", + ConversationID: uuid.New().String(), + CurrentMessage: currentMessage, + History: history, + }, + ProfileArn: profileArn, + InferenceConfig: inferenceConfig, + } + + result, err := json.Marshal(payload) + if err != nil { + log.Debugf("kiro-openai: failed to marshal payload: %v", err) + return nil, false + } + + return result, thinkingEnabled +} + +// normalizeOrigin normalizes origin value for Kiro API compatibility +func normalizeOrigin(origin string) string { + switch origin { + case "KIRO_CLI": + return "CLI" + case "KIRO_AI_EDITOR": + return "AI_EDITOR" + case "AMAZON_Q": + return "CLI" + case "KIRO_IDE": + return "AI_EDITOR" + default: + return origin + } +} + +// extractSystemPromptFromOpenAI extracts system prompt from OpenAI messages +func extractSystemPromptFromOpenAI(messages gjson.Result) string { + if !messages.IsArray() { + return "" + } + + var systemParts []string + for _, msg := range messages.Array() { + if msg.Get("role").String() == "system" { + content := msg.Get("content") + if content.Type == gjson.String { + systemParts = append(systemParts, content.String()) + } else if content.IsArray() { + // Handle array content format + for _, part := range content.Array() { + if part.Get("type").String() == "text" { + systemParts = append(systemParts, part.Get("text").String()) + } + } + } + } + } + + return strings.Join(systemParts, "\n") +} + +// shortenToolNameIfNeeded shortens tool names that exceed 64 characters. +// MCP tools often have long names like "mcp__server-name__tool-name". +// This preserves the "mcp__" prefix and last segment when possible. +func shortenToolNameIfNeeded(name string) string { + const limit = 64 + if len(name) <= limit { + return name + } + // For MCP tools, try to preserve prefix and last segment + if strings.HasPrefix(name, "mcp__") { + idx := strings.LastIndex(name, "__") + if idx > 0 { + cand := "mcp__" + name[idx+2:] + if len(cand) > limit { + return cand[:limit] + } + return cand + } + } + return name[:limit] +} + +// convertOpenAIToolsToKiro converts OpenAI tools to Kiro format +func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper { + var kiroTools []KiroToolWrapper + if !tools.IsArray() { + return kiroTools + } + + for _, tool := range tools.Array() { + // OpenAI tools have type "function" with function definition inside + if tool.Get("type").String() != "function" { + continue + } + + fn := tool.Get("function") + if !fn.Exists() { + continue + } + + name := fn.Get("name").String() + description := fn.Get("description").String() + parameters := fn.Get("parameters").Value() + + // Shorten tool name if it exceeds 64 characters (common with MCP tools) + originalName := name + name = shortenToolNameIfNeeded(name) + if name != originalName { + log.Debugf("kiro-openai: shortened tool name from '%s' to '%s'", originalName, name) + } + + // CRITICAL FIX: Kiro API requires non-empty description + if strings.TrimSpace(description) == "" { + description = fmt.Sprintf("Tool: %s", name) + log.Debugf("kiro-openai: tool '%s' has empty description, using default: %s", name, description) + } + + // Truncate long descriptions + if len(description) > kirocommon.KiroMaxToolDescLen { + truncLen := kirocommon.KiroMaxToolDescLen - 30 + for truncLen > 0 && !utf8.RuneStart(description[truncLen]) { + truncLen-- + } + description = description[:truncLen] + "... (description truncated)" + } + + kiroTools = append(kiroTools, KiroToolWrapper{ + ToolSpecification: KiroToolSpecification{ + Name: name, + Description: description, + InputSchema: KiroInputSchema{JSON: parameters}, + }, + }) + } + + return kiroTools +} + +// processOpenAIMessages processes OpenAI messages and builds Kiro history +func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) { + var history []KiroHistoryMessage + var currentUserMsg *KiroUserInputMessage + var currentToolResults []KiroToolResult + + if !messages.IsArray() { + return history, currentUserMsg, currentToolResults + } + + // Merge adjacent messages with the same role + messagesArray := kirocommon.MergeAdjacentMessages(messages.Array()) + + // Build tool_call_id to name mapping from assistant messages + toolCallIDToName := make(map[string]string) + for _, msg := range messagesArray { + if msg.Get("role").String() == "assistant" { + toolCalls := msg.Get("tool_calls") + if toolCalls.IsArray() { + for _, tc := range toolCalls.Array() { + if tc.Get("type").String() == "function" { + id := tc.Get("id").String() + name := tc.Get("function.name").String() + if id != "" && name != "" { + toolCallIDToName[id] = name + } + } + } + } + } + } + + for i, msg := range messagesArray { + role := msg.Get("role").String() + isLastMessage := i == len(messagesArray)-1 + + switch role { + case "system": + // System messages are handled separately via extractSystemPromptFromOpenAI + continue + + case "user": + userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin) + if isLastMessage { + currentUserMsg = &userMsg + currentToolResults = toolResults + } else { + // CRITICAL: Kiro API requires content to be non-empty for history messages + if strings.TrimSpace(userMsg.Content) == "" { + if len(toolResults) > 0 { + userMsg.Content = "Tool results provided." + } else { + userMsg.Content = "Continue" + } + } + // For history messages, embed tool results in context + if len(toolResults) > 0 { + userMsg.UserInputMessageContext = &KiroUserInputMessageContext{ + ToolResults: toolResults, + } + } + history = append(history, KiroHistoryMessage{ + UserInputMessage: &userMsg, + }) + } + + case "assistant": + assistantMsg := buildAssistantMessageFromOpenAI(msg) + if isLastMessage { + history = append(history, KiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + // Create a "Continue" user message as currentMessage + currentUserMsg = &KiroUserInputMessage{ + Content: "Continue", + ModelID: modelID, + Origin: origin, + } + } else { + history = append(history, KiroHistoryMessage{ + AssistantResponseMessage: &assistantMsg, + }) + } + + case "tool": + // Tool messages in OpenAI format provide results for tool_calls + // These are typically followed by user or assistant messages + // Process them and merge into the next user message's tool results + toolCallID := msg.Get("tool_call_id").String() + content := msg.Get("content").String() + + if toolCallID != "" { + toolResult := KiroToolResult{ + ToolUseID: toolCallID, + Content: []KiroTextContent{{Text: content}}, + Status: "success", + } + // Tool results should be included in the next user message + // For now, collect them and they'll be handled when we build the current message + currentToolResults = append(currentToolResults, toolResult) + } + } + } + + return history, currentUserMsg, currentToolResults +} + +// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results +func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) { + content := msg.Get("content") + var contentBuilder strings.Builder + var toolResults []KiroToolResult + var images []KiroImage + + // Track seen toolCallIds to deduplicate + seenToolCallIDs := make(map[string]bool) + + if content.IsArray() { + for _, part := range content.Array() { + partType := part.Get("type").String() + switch partType { + case "text": + contentBuilder.WriteString(part.Get("text").String()) + case "image_url": + imageURL := part.Get("image_url.url").String() + if strings.HasPrefix(imageURL, "data:") { + // Parse data URL: data:image/png;base64,xxxxx + if idx := strings.Index(imageURL, ";base64,"); idx != -1 { + mediaType := imageURL[5:idx] // Skip "data:" + data := imageURL[idx+8:] // Skip ";base64," + + format := "" + if lastSlash := strings.LastIndex(mediaType, "/"); lastSlash != -1 { + format = mediaType[lastSlash+1:] + } + + if format != "" && data != "" { + images = append(images, KiroImage{ + Format: format, + Source: KiroImageSource{ + Bytes: data, + }, + }) + } + } + } + } + } + } else if content.Type == gjson.String { + contentBuilder.WriteString(content.String()) + } + + // Check for tool_calls in the message (shouldn't be in user messages, but handle edge cases) + _ = seenToolCallIDs // Used for deduplication if needed + + userMsg := KiroUserInputMessage{ + Content: contentBuilder.String(), + ModelID: modelID, + Origin: origin, + } + + if len(images) > 0 { + userMsg.Images = images + } + + return userMsg, toolResults +} + +// buildAssistantMessageFromOpenAI builds an assistant message from OpenAI format +func buildAssistantMessageFromOpenAI(msg gjson.Result) KiroAssistantResponseMessage { + content := msg.Get("content") + var contentBuilder strings.Builder + var toolUses []KiroToolUse + + // Handle content + if content.Type == gjson.String { + contentBuilder.WriteString(content.String()) + } else if content.IsArray() { + for _, part := range content.Array() { + if part.Get("type").String() == "text" { + contentBuilder.WriteString(part.Get("text").String()) + } + } + } + + // Handle tool_calls + toolCalls := msg.Get("tool_calls") + if toolCalls.IsArray() { + for _, tc := range toolCalls.Array() { + if tc.Get("type").String() != "function" { + continue + } + + toolUseID := tc.Get("id").String() + toolName := tc.Get("function.name").String() + toolArgs := tc.Get("function.arguments").String() + + var inputMap map[string]interface{} + if err := json.Unmarshal([]byte(toolArgs), &inputMap); err != nil { + log.Debugf("kiro-openai: failed to parse tool arguments: %v", err) + inputMap = make(map[string]interface{}) + } + + toolUses = append(toolUses, KiroToolUse{ + ToolUseID: toolUseID, + Name: toolName, + Input: inputMap, + }) + } + } + + return KiroAssistantResponseMessage{ + Content: contentBuilder.String(), + ToolUses: toolUses, + } +} + +// buildFinalContent builds the final content with system prompt +func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string { + var contentBuilder strings.Builder + + if systemPrompt != "" { + contentBuilder.WriteString("--- SYSTEM PROMPT ---\n") + contentBuilder.WriteString(systemPrompt) + contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n") + } + + contentBuilder.WriteString(content) + finalContent := contentBuilder.String() + + // CRITICAL: Kiro API requires content to be non-empty + if strings.TrimSpace(finalContent) == "" { + if len(toolResults) > 0 { + finalContent = "Tool results provided." + } else { + finalContent = "Continue" + } + log.Debugf("kiro-openai: content was empty, using default: %s", finalContent) + } + + return finalContent +} + +// checkThinkingModeFromOpenAI checks if thinking mode is enabled in the OpenAI request. +// Returns (thinkingEnabled, budgetTokens). +// Supports: +// - reasoning_effort parameter (low/medium/high/auto) +// - Model name containing "thinking" or "reason" +// - tag in system prompt (AMP/Cursor format) +func checkThinkingModeFromOpenAI(openaiBody []byte) (bool, int64) { + var budgetTokens int64 = 16000 // Default budget + + // Check OpenAI format: reasoning_effort parameter + // Valid values: "low", "medium", "high", "auto" (not "none") + reasoningEffort := gjson.GetBytes(openaiBody, "reasoning_effort") + if reasoningEffort.Exists() { + effort := reasoningEffort.String() + if effort != "" && effort != "none" { + log.Debugf("kiro-openai: thinking mode enabled via reasoning_effort: %s", effort) + // Adjust budget based on effort level + switch effort { + case "low": + budgetTokens = 8000 + case "medium": + budgetTokens = 16000 + case "high": + budgetTokens = 32000 + case "auto": + budgetTokens = 16000 + } + return true, budgetTokens + } + } + + // Check AMP/Cursor format: interleaved in system prompt + bodyStr := string(openaiBody) + if strings.Contains(bodyStr, "") && strings.Contains(bodyStr, "") { + startTag := "" + endTag := "" + startIdx := strings.Index(bodyStr, startTag) + if startIdx >= 0 { + startIdx += len(startTag) + endIdx := strings.Index(bodyStr[startIdx:], endTag) + if endIdx >= 0 { + thinkingMode := bodyStr[startIdx : startIdx+endIdx] + if thinkingMode == "interleaved" || thinkingMode == "enabled" { + log.Debugf("kiro-openai: thinking mode enabled via AMP/Cursor format: %s", thinkingMode) + // Try to extract max_thinking_length if present + if maxLenStart := strings.Index(bodyStr, ""); maxLenStart >= 0 { + maxLenStart += len("") + if maxLenEnd := strings.Index(bodyStr[maxLenStart:], ""); maxLenEnd >= 0 { + maxLenStr := bodyStr[maxLenStart : maxLenStart+maxLenEnd] + if parsed, err := fmt.Sscanf(maxLenStr, "%d", &budgetTokens); err == nil && parsed == 1 { + log.Debugf("kiro-openai: extracted max_thinking_length: %d", budgetTokens) + } + } + } + return true, budgetTokens + } + } + } + } + + // Check model name for thinking hints + model := gjson.GetBytes(openaiBody, "model").String() + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") { + log.Debugf("kiro-openai: thinking mode enabled via model name hint: %s", model) + return true, budgetTokens + } + + log.Debugf("kiro-openai: no thinking mode detected in OpenAI request") + return false, budgetTokens +} + +// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint. +// OpenAI tool_choice values: +// - "none": Don't use any tools +// - "auto": Model decides (default, no hint needed) +// - "required": Must use at least one tool +// - {"type":"function","function":{"name":"..."}} : Must use specific tool +func extractToolChoiceHint(openaiBody []byte) string { + toolChoice := gjson.GetBytes(openaiBody, "tool_choice") + if !toolChoice.Exists() { + return "" + } + + // Handle string values + if toolChoice.Type == gjson.String { + switch toolChoice.String() { + case "none": + // Note: When tool_choice is "none", we should ideally not pass tools at all + // But since we can't modify tool passing here, we add a strong hint + return "[INSTRUCTION: Do NOT use any tools. Respond with text only.]" + case "required": + return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]" + case "auto": + // Default behavior, no hint needed + return "" + } + } + + // Handle object value: {"type":"function","function":{"name":"..."}} + if toolChoice.IsObject() { + if toolChoice.Get("type").String() == "function" { + toolName := toolChoice.Get("function.name").String() + if toolName != "" { + return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName) + } + } + } + + return "" +} + +// extractResponseFormatHint extracts response_format from OpenAI request and returns a system prompt hint. +// OpenAI response_format values: +// - {"type": "text"}: Default, no hint needed +// - {"type": "json_object"}: Must respond with valid JSON +// - {"type": "json_schema", "json_schema": {...}}: Must respond with JSON matching schema +func extractResponseFormatHint(openaiBody []byte) string { + responseFormat := gjson.GetBytes(openaiBody, "response_format") + if !responseFormat.Exists() { + return "" + } + + formatType := responseFormat.Get("type").String() + switch formatType { + case "json_object": + return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]" + case "json_schema": + // Extract schema if provided + schema := responseFormat.Get("json_schema.schema") + if schema.Exists() { + schemaStr := schema.Raw + // Truncate if too long + if len(schemaStr) > 500 { + schemaStr = schemaStr[:500] + "..." + } + return fmt.Sprintf("[INSTRUCTION: You MUST respond with valid JSON that matches this schema: %s. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]", schemaStr) + } + return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]" + case "text": + // Default behavior, no hint needed + return "" + } + + return "" +} + +// deduplicateToolResults removes duplicate tool results +func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult { + if len(toolResults) == 0 { + return toolResults + } + + seenIDs := make(map[string]bool) + unique := make([]KiroToolResult, 0, len(toolResults)) + for _, tr := range toolResults { + if !seenIDs[tr.ToolUseID] { + seenIDs[tr.ToolUseID] = true + unique = append(unique, tr) + } else { + log.Debugf("kiro-openai: skipping duplicate toolResult: %s", tr.ToolUseID) + } + } + return unique +} \ No newline at end of file diff --git a/internal/translator/kiro/openai/kiro_openai_response.go b/internal/translator/kiro/openai/kiro_openai_response.go new file mode 100644 index 00000000..b7da1373 --- /dev/null +++ b/internal/translator/kiro/openai/kiro_openai_response.go @@ -0,0 +1,264 @@ +// Package openai provides response translation from Kiro to OpenAI format. +// This package handles the conversion of Kiro API responses into OpenAI Chat Completions-compatible +// JSON format, transforming streaming events and non-streaming responses. +package openai + +import ( + "encoding/json" + "fmt" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + log "github.com/sirupsen/logrus" +) + +// functionCallIDCounter provides a process-wide unique counter for function call identifiers. +var functionCallIDCounter uint64 + +// BuildOpenAIResponse constructs an OpenAI Chat Completions-compatible response. +// Supports tool_calls when tools are present in the response. +// stopReason is passed from upstream; fallback logic applied if empty. +func BuildOpenAIResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte { + // Build the message object + message := map[string]interface{}{ + "role": "assistant", + "content": content, + } + + // Add tool_calls if present + if len(toolUses) > 0 { + var toolCalls []map[string]interface{} + for i, tu := range toolUses { + inputJSON, _ := json.Marshal(tu.Input) + toolCalls = append(toolCalls, map[string]interface{}{ + "id": tu.ToolUseID, + "type": "function", + "index": i, + "function": map[string]interface{}{ + "name": tu.Name, + "arguments": string(inputJSON), + }, + }) + } + message["tool_calls"] = toolCalls + // When tool_calls are present, content should be null according to OpenAI spec + if content == "" { + message["content"] = nil + } + } + + // Use upstream stopReason; apply fallback logic if not provided + finishReason := mapKiroStopReasonToOpenAI(stopReason) + if finishReason == "" { + finishReason = "stop" + if len(toolUses) > 0 { + finishReason = "tool_calls" + } + log.Debugf("kiro-openai: buildOpenAIResponse using fallback finish_reason: %s", finishReason) + } + + response := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:24], + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{ + { + "index": 0, + "message": message, + "finish_reason": finishReason, + }, + }, + "usage": map[string]interface{}{ + "prompt_tokens": usageInfo.InputTokens, + "completion_tokens": usageInfo.OutputTokens, + "total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens, + }, + } + + result, _ := json.Marshal(response) + return result +} + +// mapKiroStopReasonToOpenAI converts Kiro/Claude stop_reason to OpenAI finish_reason +func mapKiroStopReasonToOpenAI(stopReason string) string { + switch stopReason { + case "end_turn": + return "stop" + case "stop_sequence": + return "stop" + case "tool_use": + return "tool_calls" + case "max_tokens": + return "length" + case "content_filtered": + return "content_filter" + default: + return stopReason + } +} + +// BuildOpenAIStreamChunk constructs an OpenAI Chat Completions streaming chunk. +// This is the delta format used in streaming responses. +func BuildOpenAIStreamChunk(model string, deltaContent string, deltaToolCalls []map[string]interface{}, finishReason string, index int) []byte { + delta := map[string]interface{}{} + + // First chunk should include role + if index == 0 && deltaContent == "" && len(deltaToolCalls) == 0 { + delta["role"] = "assistant" + delta["content"] = "" + } else if deltaContent != "" { + delta["content"] = deltaContent + } + + // Add tool_calls delta if present + if len(deltaToolCalls) > 0 { + delta["tool_calls"] = deltaToolCalls + } + + choice := map[string]interface{}{ + "index": 0, + "delta": delta, + } + + if finishReason != "" { + choice["finish_reason"] = finishReason + } else { + choice["finish_reason"] = nil + } + + chunk := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:12], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{choice}, + } + + result, _ := json.Marshal(chunk) + return result +} + +// BuildOpenAIStreamChunkWithToolCallStart creates a stream chunk for tool call start +func BuildOpenAIStreamChunkWithToolCallStart(model string, toolUseID, toolName string, toolIndex int) []byte { + toolCall := map[string]interface{}{ + "index": toolIndex, + "id": toolUseID, + "type": "function", + "function": map[string]interface{}{ + "name": toolName, + "arguments": "", + }, + } + + delta := map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + } + + choice := map[string]interface{}{ + "index": 0, + "delta": delta, + "finish_reason": nil, + } + + chunk := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:12], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{choice}, + } + + result, _ := json.Marshal(chunk) + return result +} + +// BuildOpenAIStreamChunkWithToolCallDelta creates a stream chunk for tool call arguments delta +func BuildOpenAIStreamChunkWithToolCallDelta(model string, argumentsDelta string, toolIndex int) []byte { + toolCall := map[string]interface{}{ + "index": toolIndex, + "function": map[string]interface{}{ + "arguments": argumentsDelta, + }, + } + + delta := map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + } + + choice := map[string]interface{}{ + "index": 0, + "delta": delta, + "finish_reason": nil, + } + + chunk := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:12], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{choice}, + } + + result, _ := json.Marshal(chunk) + return result +} + +// BuildOpenAIStreamDoneChunk creates the final [DONE] stream event +func BuildOpenAIStreamDoneChunk() []byte { + return []byte("data: [DONE]") +} + +// BuildOpenAIStreamFinishChunk creates the final chunk with finish_reason +func BuildOpenAIStreamFinishChunk(model string, finishReason string) []byte { + choice := map[string]interface{}{ + "index": 0, + "delta": map[string]interface{}{}, + "finish_reason": finishReason, + } + + chunk := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:12], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{choice}, + } + + result, _ := json.Marshal(chunk) + return result +} + +// BuildOpenAIStreamUsageChunk creates a chunk with usage information (optional, for stream_options.include_usage) +func BuildOpenAIStreamUsageChunk(model string, usageInfo usage.Detail) []byte { + chunk := map[string]interface{}{ + "id": "chatcmpl-" + uuid.New().String()[:12], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]interface{}{}, + "usage": map[string]interface{}{ + "prompt_tokens": usageInfo.InputTokens, + "completion_tokens": usageInfo.OutputTokens, + "total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens, + }, + } + + result, _ := json.Marshal(chunk) + return result +} + +// GenerateToolCallID generates a unique tool call ID in OpenAI format +func GenerateToolCallID(toolName string) string { + return fmt.Sprintf("call_%s_%d_%d", toolName[:min(8, len(toolName))], time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)) +} + +// min returns the minimum of two integers +func min(a, b int) int { + if a < b { + return a + } + return b +} \ No newline at end of file diff --git a/internal/translator/kiro/openai/kiro_openai_stream.go b/internal/translator/kiro/openai/kiro_openai_stream.go new file mode 100644 index 00000000..e72d970e --- /dev/null +++ b/internal/translator/kiro/openai/kiro_openai_stream.go @@ -0,0 +1,212 @@ +// Package openai provides streaming SSE event building for OpenAI format. +// This package handles the construction of OpenAI-compatible Server-Sent Events (SSE) +// for streaming responses from Kiro API. +package openai + +import ( + "encoding/json" + "time" + + "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" +) + +// OpenAIStreamState tracks the state of streaming response conversion +type OpenAIStreamState struct { + ChunkIndex int + ToolCallIndex int + HasSentFirstChunk bool + Model string + ResponseID string + Created int64 +} + +// NewOpenAIStreamState creates a new stream state for tracking +func NewOpenAIStreamState(model string) *OpenAIStreamState { + return &OpenAIStreamState{ + ChunkIndex: 0, + ToolCallIndex: 0, + HasSentFirstChunk: false, + Model: model, + ResponseID: "chatcmpl-" + uuid.New().String()[:24], + Created: time.Now().Unix(), + } +} + +// FormatSSEEvent formats a JSON payload for SSE streaming. +// Note: This returns raw JSON data without "data:" prefix. +// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go) +// to maintain architectural consistency and avoid double-prefix issues. +func FormatSSEEvent(data []byte) string { + return string(data) +} + +// BuildOpenAISSETextDelta creates an SSE event for text content delta +func BuildOpenAISSETextDelta(state *OpenAIStreamState, textDelta string) string { + delta := map[string]interface{}{ + "content": textDelta, + } + + // Include role in first chunk + if !state.HasSentFirstChunk { + delta["role"] = "assistant" + state.HasSentFirstChunk = true + } + + chunk := buildBaseChunk(state, delta, nil) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// BuildOpenAISSEToolCallStart creates an SSE event for tool call start +func BuildOpenAISSEToolCallStart(state *OpenAIStreamState, toolUseID, toolName string) string { + toolCall := map[string]interface{}{ + "index": state.ToolCallIndex, + "id": toolUseID, + "type": "function", + "function": map[string]interface{}{ + "name": toolName, + "arguments": "", + }, + } + + delta := map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + } + + // Include role in first chunk if not sent yet + if !state.HasSentFirstChunk { + delta["role"] = "assistant" + state.HasSentFirstChunk = true + } + + chunk := buildBaseChunk(state, delta, nil) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// BuildOpenAISSEToolCallArgumentsDelta creates an SSE event for tool call arguments delta +func BuildOpenAISSEToolCallArgumentsDelta(state *OpenAIStreamState, argumentsDelta string, toolIndex int) string { + toolCall := map[string]interface{}{ + "index": toolIndex, + "function": map[string]interface{}{ + "arguments": argumentsDelta, + }, + } + + delta := map[string]interface{}{ + "tool_calls": []map[string]interface{}{toolCall}, + } + + chunk := buildBaseChunk(state, delta, nil) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// BuildOpenAISSEFinish creates an SSE event with finish_reason +func BuildOpenAISSEFinish(state *OpenAIStreamState, finishReason string) string { + chunk := buildBaseChunk(state, map[string]interface{}{}, &finishReason) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// BuildOpenAISSEUsage creates an SSE event with usage information +func BuildOpenAISSEUsage(state *OpenAIStreamState, usageInfo usage.Detail) string { + chunk := map[string]interface{}{ + "id": state.ResponseID, + "object": "chat.completion.chunk", + "created": state.Created, + "model": state.Model, + "choices": []map[string]interface{}{}, + "usage": map[string]interface{}{ + "prompt_tokens": usageInfo.InputTokens, + "completion_tokens": usageInfo.OutputTokens, + "total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens, + }, + } + result, _ := json.Marshal(chunk) + return FormatSSEEvent(result) +} + +// BuildOpenAISSEDone creates the final [DONE] SSE event. +// Note: This returns raw "[DONE]" without "data:" prefix. +// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go) +// to maintain architectural consistency and avoid double-prefix issues. +func BuildOpenAISSEDone() string { + return "[DONE]" +} + +// buildBaseChunk creates a base chunk structure for streaming +func buildBaseChunk(state *OpenAIStreamState, delta map[string]interface{}, finishReason *string) map[string]interface{} { + choice := map[string]interface{}{ + "index": 0, + "delta": delta, + } + + if finishReason != nil { + choice["finish_reason"] = *finishReason + } else { + choice["finish_reason"] = nil + } + + return map[string]interface{}{ + "id": state.ResponseID, + "object": "chat.completion.chunk", + "created": state.Created, + "model": state.Model, + "choices": []map[string]interface{}{choice}, + } +} + +// BuildOpenAISSEReasoningDelta creates an SSE event for reasoning content delta +// This is used for o1/o3 style models that expose reasoning tokens +func BuildOpenAISSEReasoningDelta(state *OpenAIStreamState, reasoningDelta string) string { + delta := map[string]interface{}{ + "reasoning_content": reasoningDelta, + } + + // Include role in first chunk + if !state.HasSentFirstChunk { + delta["role"] = "assistant" + state.HasSentFirstChunk = true + } + + chunk := buildBaseChunk(state, delta, nil) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// BuildOpenAISSEFirstChunk creates the first chunk with role only +func BuildOpenAISSEFirstChunk(state *OpenAIStreamState) string { + delta := map[string]interface{}{ + "role": "assistant", + "content": "", + } + + state.HasSentFirstChunk = true + chunk := buildBaseChunk(state, delta, nil) + result, _ := json.Marshal(chunk) + state.ChunkIndex++ + return FormatSSEEvent(result) +} + +// ThinkingTagState tracks state for thinking tag detection in streaming +type ThinkingTagState struct { + InThinkingBlock bool + PendingStartChars int + PendingEndChars int +} + +// NewThinkingTagState creates a new thinking tag state +func NewThinkingTagState() *ThinkingTagState { + return &ThinkingTagState{ + InThinkingBlock: false, + PendingStartChars: 0, + PendingEndChars: 0, + } +} \ No newline at end of file diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 43a3a3dc..7c2976ee 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -22,6 +22,7 @@ import ( "github.com/fsnotify/fsnotify" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" "gopkg.in/yaml.v3" @@ -179,6 +180,9 @@ func (w *Watcher) Start(ctx context.Context) error { } log.Debugf("watching auth directory: %s", w.authDir) + // Watch Kiro IDE token file directory for automatic token updates + w.watchKiroIDETokenFile() + // Start the event processing goroutine go w.processEvents(ctx) @@ -187,6 +191,31 @@ func (w *Watcher) Start(ctx context.Context) error { return nil } +// watchKiroIDETokenFile adds the Kiro IDE token file directory to the watcher. +// This enables automatic detection of token updates from Kiro IDE. +func (w *Watcher) watchKiroIDETokenFile() { + homeDir, err := os.UserHomeDir() + if err != nil { + log.Debugf("failed to get home directory for Kiro IDE token watch: %v", err) + return + } + + // Kiro IDE stores tokens in ~/.aws/sso/cache/ + kiroTokenDir := filepath.Join(homeDir, ".aws", "sso", "cache") + + // Check if directory exists + if _, statErr := os.Stat(kiroTokenDir); os.IsNotExist(statErr) { + log.Debugf("Kiro IDE token directory does not exist: %s", kiroTokenDir) + return + } + + if errAdd := w.watcher.Add(kiroTokenDir); errAdd != nil { + log.Debugf("failed to watch Kiro IDE token directory %s: %v", kiroTokenDir, errAdd) + return + } + log.Debugf("watching Kiro IDE token directory: %s", kiroTokenDir) +} + // Stop stops the file watcher func (w *Watcher) Stop() error { w.stopDispatch() @@ -791,11 +820,21 @@ func (w *Watcher) handleEvent(event fsnotify.Event) { normalizedAuthDir := w.normalizeAuthPath(w.authDir) isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0 authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename - isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0 - if !isConfigEvent && !isAuthJSON { + isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0 + + // Check for Kiro IDE token file changes + isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0 + + if !isConfigEvent && !isAuthJSON && !isKiroIDEToken { // Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise. return } + + // Handle Kiro IDE token file changes + if isKiroIDEToken { + w.handleKiroIDETokenChange(event) + return + } now := time.Now() log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name) @@ -857,6 +896,51 @@ func (w *Watcher) scheduleConfigReload() { }) } +// isKiroIDETokenFile checks if the given path is the Kiro IDE token file. +func (w *Watcher) isKiroIDETokenFile(path string) bool { + // Check if it's the kiro-auth-token.json file in ~/.aws/sso/cache/ + // Use filepath.ToSlash to ensure consistent separators across platforms (Windows uses backslashes) + normalized := filepath.ToSlash(path) + return strings.HasSuffix(normalized, "kiro-auth-token.json") && strings.Contains(normalized, ".aws/sso/cache") +} + +// handleKiroIDETokenChange processes changes to the Kiro IDE token file. +// When the token file is updated by Kiro IDE, this triggers a reload of Kiro auth. +func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) { + log.Debugf("Kiro IDE token file event detected: %s %s", event.Op.String(), event.Name) + + if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 { + // Token file removed - wait briefly for potential atomic replace + time.Sleep(replaceCheckDelay) + if _, statErr := os.Stat(event.Name); statErr != nil { + log.Debugf("Kiro IDE token file removed: %s", event.Name) + return + } + } + + // Try to load the updated token + tokenData, err := kiroauth.LoadKiroIDEToken() + if err != nil { + log.Debugf("failed to load Kiro IDE token after change: %v", err) + return + } + + log.Infof("Kiro IDE token file updated, access token refreshed (provider: %s)", tokenData.Provider) + + // Trigger auth state refresh to pick up the new token + w.refreshAuthState() + + // Notify callback if set + w.clientsMutex.RLock() + cfg := w.config + w.clientsMutex.RUnlock() + + if w.reloadCallback != nil && cfg != nil { + log.Debugf("triggering server update callback after Kiro IDE token change") + w.reloadCallback(cfg) + } +} + func (w *Watcher) reloadConfigIfChanged() { data, err := os.ReadFile(w.configPath) if err != nil { @@ -1236,6 +1320,88 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey") out = append(out, a) } + // Kiro (AWS CodeWhisperer) -> synthesize auths + var kAuth *kiroauth.KiroAuth + if len(cfg.KiroKey) > 0 { + kAuth = kiroauth.NewKiroAuth(cfg) + } + for i := range cfg.KiroKey { + kk := cfg.KiroKey[i] + var accessToken, profileArn, refreshToken string + + // Try to load from token file first + if kk.TokenFile != "" && kAuth != nil { + tokenData, err := kAuth.LoadTokenFromFile(kk.TokenFile) + if err != nil { + log.Warnf("failed to load kiro token file %s: %v", kk.TokenFile, err) + } else { + accessToken = tokenData.AccessToken + profileArn = tokenData.ProfileArn + refreshToken = tokenData.RefreshToken + } + } + + // Override with direct config values if provided + if kk.AccessToken != "" { + accessToken = kk.AccessToken + } + if kk.ProfileArn != "" { + profileArn = kk.ProfileArn + } + if kk.RefreshToken != "" { + refreshToken = kk.RefreshToken + } + + if accessToken == "" { + log.Warnf("kiro config[%d] missing access_token, skipping", i) + continue + } + + // profileArn is optional for AWS Builder ID users + id, token := idGen.next("kiro:token", accessToken, profileArn) + attrs := map[string]string{ + "source": fmt.Sprintf("config:kiro[%s]", token), + "access_token": accessToken, + } + if profileArn != "" { + attrs["profile_arn"] = profileArn + } + if kk.Region != "" { + attrs["region"] = kk.Region + } + if kk.AgentTaskType != "" { + attrs["agent_task_type"] = kk.AgentTaskType + } + if kk.PreferredEndpoint != "" { + attrs["preferred_endpoint"] = kk.PreferredEndpoint + } else if cfg.KiroPreferredEndpoint != "" { + // Apply global default if not overridden by specific key + attrs["preferred_endpoint"] = cfg.KiroPreferredEndpoint + } + if refreshToken != "" { + attrs["refresh_token"] = refreshToken + } + proxyURL := strings.TrimSpace(kk.ProxyURL) + a := &coreauth.Auth{ + ID: id, + Provider: "kiro", + Label: "kiro-token", + Status: coreauth.StatusActive, + ProxyURL: proxyURL, + Attributes: attrs, + CreatedAt: now, + UpdatedAt: now, + } + + if refreshToken != "" { + if a.Metadata == nil { + a.Metadata = make(map[string]any) + } + a.Metadata["refresh_token"] = refreshToken + } + + out = append(out, a) + } for i := range cfg.OpenAICompatibility { compat := &cfg.OpenAICompatibility[i] providerName := strings.ToLower(strings.TrimSpace(compat.Name)) @@ -1342,7 +1508,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } // Also synthesize auth entries directly from auth files (for OAuth/file-backed providers) - entries, _ := os.ReadDir(w.authDir) + log.Debugf("SnapshotCoreAuths: scanning auth directory: %s", w.authDir) + entries, readErr := os.ReadDir(w.authDir) + if readErr != nil { + log.Errorf("SnapshotCoreAuths: failed to read auth directory %s: %v", w.authDir, readErr) + } + log.Debugf("SnapshotCoreAuths: found %d entries in auth directory", len(entries)) for _, e := range entries { if e.IsDir() { continue @@ -1361,9 +1532,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { continue } t, _ := metadata["type"].(string) + + // Detect Kiro auth files by auth_method field (they don't have "type" field) if t == "" { + if authMethod, _ := metadata["auth_method"].(string); authMethod == "builder-id" || authMethod == "social" { + t = "kiro" + log.Debugf("SnapshotCoreAuths: detected Kiro auth by auth_method: %s", name) + } + } + + if t == "" { + log.Debugf("SnapshotCoreAuths: skipping file without type: %s", name) continue } + log.Debugf("SnapshotCoreAuths: processing auth file: %s (type=%s)", name, t) provider := strings.ToLower(t) if provider == "gemini" { provider = "gemini-cli" @@ -1372,6 +1554,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if email, _ := metadata["email"].(string); email != "" { label = email } + // For Kiro, use provider field as label if available + if provider == "kiro" { + if kiroProvider, _ := metadata["provider"].(string); kiroProvider != "" { + label = fmt.Sprintf("kiro-%s", strings.ToLower(kiroProvider)) + } + } // Use relative path under authDir as ID to stay consistent with the file-based token store id := full if rel, errRel := filepath.Rel(w.authDir, full); errRel == nil && rel != "" { @@ -1397,6 +1585,27 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { CreatedAt: now, UpdatedAt: now, } + // Set NextRefreshAfter for Kiro auth based on expires_at + if provider == "kiro" { + if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" { + if expiresAt, parseErr := time.Parse(time.RFC3339, expiresAtStr); parseErr == nil { + // Refresh 30 minutes before expiry + a.NextRefreshAfter = expiresAt.Add(-30 * time.Minute) + } + } + + // Apply global preferred endpoint setting if not present in metadata + if cfg.KiroPreferredEndpoint != "" { + // Check if already set in metadata (which takes precedence in executor) + if _, hasMeta := metadata["preferred_endpoint"]; !hasMeta { + if a.Attributes == nil { + a.Attributes = make(map[string]string) + } + a.Attributes["preferred_endpoint"] = cfg.KiroPreferredEndpoint + } + } + } + applyAuthExcludedModelsMeta(a, cfg, nil, "oauth") if provider == "gemini-cli" { if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index a17e54aa..9ad6de81 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "strings" + "sync" "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" @@ -51,7 +52,25 @@ type BaseAPIHandler struct { Cfg *config.SDKConfig // OpenAICompatProviders is a list of provider names for OpenAI compatibility. - OpenAICompatProviders []string + openAICompatProviders []string + openAICompatMutex sync.RWMutex +} + +// GetOpenAICompatProviders safely returns a copy of the provider names +func (h *BaseAPIHandler) GetOpenAICompatProviders() []string { + h.openAICompatMutex.RLock() + defer h.openAICompatMutex.RUnlock() + result := make([]string, len(h.openAICompatProviders)) + copy(result, h.openAICompatProviders) + return result +} + +// SetOpenAICompatProviders safely sets the provider names +func (h *BaseAPIHandler) SetOpenAICompatProviders(providers []string) { + h.openAICompatMutex.Lock() + defer h.openAICompatMutex.Unlock() + h.openAICompatProviders = make([]string, len(providers)) + copy(h.openAICompatProviders, providers) } // NewBaseAPIHandlers creates a new API handlers instance. @@ -64,11 +83,12 @@ type BaseAPIHandler struct { // Returns: // - *BaseAPIHandler: A new API handlers instance func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler { - return &BaseAPIHandler{ - Cfg: cfg, - AuthManager: authManager, - OpenAICompatProviders: openAICompatProviders, + h := &BaseAPIHandler{ + Cfg: cfg, + AuthManager: authManager, } + h.SetOpenAICompatProviders(openAICompatProviders) + return h } // UpdateClients updates the handlers' client list and configuration. @@ -398,7 +418,7 @@ func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, mode } // Check if the provider is a configured openai-compatibility provider - for _, pName := range h.OpenAICompatProviders { + for _, pName := range h.GetOpenAICompatProviders() { if pName == providerPart { return providerPart, modelPart, true } diff --git a/sdk/auth/github_copilot.go b/sdk/auth/github_copilot.go new file mode 100644 index 00000000..1d14ac47 --- /dev/null +++ b/sdk/auth/github_copilot.go @@ -0,0 +1,129 @@ +package auth + +import ( + "context" + "fmt" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot" + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +// GitHubCopilotAuthenticator implements the OAuth device flow login for GitHub Copilot. +type GitHubCopilotAuthenticator struct{} + +// NewGitHubCopilotAuthenticator constructs a new GitHub Copilot authenticator. +func NewGitHubCopilotAuthenticator() Authenticator { + return &GitHubCopilotAuthenticator{} +} + +// Provider returns the provider key for github-copilot. +func (GitHubCopilotAuthenticator) Provider() string { + return "github-copilot" +} + +// RefreshLead returns nil since GitHub OAuth tokens don't expire in the traditional sense. +// The token remains valid until the user revokes it or the Copilot subscription expires. +func (GitHubCopilotAuthenticator) RefreshLead() *time.Duration { + return nil +} + +// Login initiates the GitHub device flow authentication for Copilot access. +func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("cliproxy auth: configuration is required") + } + if opts == nil { + opts = &LoginOptions{} + } + + authSvc := copilot.NewCopilotAuth(cfg) + + // Start the device flow + fmt.Println("Starting GitHub Copilot authentication...") + deviceCode, err := authSvc.StartDeviceFlow(ctx) + if err != nil { + return nil, fmt.Errorf("github-copilot: failed to start device flow: %w", err) + } + + // Display the user code and verification URL + fmt.Printf("\nTo authenticate, please visit: %s\n", deviceCode.VerificationURI) + fmt.Printf("And enter the code: %s\n\n", deviceCode.UserCode) + + // Try to open the browser automatically + if !opts.NoBrowser { + if browser.IsAvailable() { + if errOpen := browser.OpenURL(deviceCode.VerificationURI); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + } + } + } + + fmt.Println("Waiting for GitHub authorization...") + fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn) + + // Wait for user authorization + authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode) + if err != nil { + errMsg := copilot.GetUserFriendlyMessage(err) + return nil, fmt.Errorf("github-copilot: %s", errMsg) + } + + // Verify the token can get a Copilot API token + fmt.Println("Verifying Copilot access...") + apiToken, err := authSvc.GetCopilotAPIToken(ctx, authBundle.TokenData.AccessToken) + if err != nil { + return nil, fmt.Errorf("github-copilot: failed to verify Copilot access - you may not have an active Copilot subscription: %w", err) + } + + // Create the token storage + tokenStorage := authSvc.CreateTokenStorage(authBundle) + + // Build metadata with token information for the executor + metadata := map[string]any{ + "type": "github-copilot", + "username": authBundle.Username, + "access_token": authBundle.TokenData.AccessToken, + "token_type": authBundle.TokenData.TokenType, + "scope": authBundle.TokenData.Scope, + "timestamp": time.Now().UnixMilli(), + } + + if apiToken.ExpiresAt > 0 { + metadata["api_token_expires_at"] = apiToken.ExpiresAt + } + + fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username) + + fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username) + + return &coreauth.Auth{ + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Label: authBundle.Username, + Storage: tokenStorage, + Metadata: metadata, + }, nil +} + +// RefreshGitHubCopilotToken validates and returns the current token status. +// GitHub OAuth tokens don't need traditional refresh - we just validate they still work. +func RefreshGitHubCopilotToken(ctx context.Context, cfg *config.Config, storage *copilot.CopilotTokenStorage) error { + if storage == nil || storage.AccessToken == "" { + return fmt.Errorf("no token available") + } + + authSvc := copilot.NewCopilotAuth(cfg) + + // Validate the token can still get a Copilot API token + _, err := authSvc.GetCopilotAPIToken(ctx, storage.AccessToken) + if err != nil { + return fmt.Errorf("token validation failed: %w", err) + } + + return nil +} diff --git a/sdk/auth/kiro.go b/sdk/auth/kiro.go new file mode 100644 index 00000000..1eed4b94 --- /dev/null +++ b/sdk/auth/kiro.go @@ -0,0 +1,363 @@ +package auth + +import ( + "context" + "fmt" + "strings" + "time" + + kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +// extractKiroIdentifier extracts a meaningful identifier for file naming. +// Returns account name if provided, otherwise profile ARN ID. +// All extracted values are sanitized to prevent path injection attacks. +func extractKiroIdentifier(accountName, profileArn string) string { + // Priority 1: Use account name if provided + if accountName != "" { + return kiroauth.SanitizeEmailForFilename(accountName) + } + + // Priority 2: Use profile ARN ID part (sanitized to prevent path injection) + if profileArn != "" { + parts := strings.Split(profileArn, "/") + if len(parts) >= 2 { + // Sanitize the ARN component to prevent path traversal + return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1]) + } + } + + // Fallback: timestamp + return fmt.Sprintf("%d", time.Now().UnixNano()%100000) +} + +// KiroAuthenticator implements OAuth authentication for Kiro with Google login. +type KiroAuthenticator struct{} + +// NewKiroAuthenticator constructs a Kiro authenticator. +func NewKiroAuthenticator() *KiroAuthenticator { + return &KiroAuthenticator{} +} + +// Provider returns the provider key for the authenticator. +func (a *KiroAuthenticator) Provider() string { + return "kiro" +} + +// RefreshLead indicates how soon before expiry a refresh should be attempted. +// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh. +func (a *KiroAuthenticator) RefreshLead() *time.Duration { + d := 5 * time.Minute + return &d +} + +// Login performs OAuth login for Kiro with AWS Builder ID. +func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use AWS Builder ID device code flow + tokenData, err := oauth.LoginWithBuilderID(ctx) + if err != nil { + return nil, fmt.Errorf("login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-aws-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-aws", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "client_id": tokenData.ClientID, + "client_secret": tokenData.ClientSecret, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "aws-builder-id", + "email": tokenData.Email, + }, + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro authentication completed successfully!") + } + + return record, nil +} + +// LoginWithGoogle performs OAuth login for Kiro with Google. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use Google OAuth flow with protocol handler + tokenData, err := oauth.LoginWithGoogle(ctx) + if err != nil { + return nil, fmt.Errorf("google login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-google-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-google", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "google-oauth", + "email": tokenData.Email, + }, + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro Google authentication completed successfully! (Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro Google authentication completed successfully!") + } + + return record, nil +} + +// LoginWithGitHub performs OAuth login for Kiro with GitHub. +// This uses a custom protocol handler (kiro://) to receive the callback. +func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if cfg == nil { + return nil, fmt.Errorf("kiro auth: configuration is required") + } + + oauth := kiroauth.NewKiroOAuth(cfg) + + // Use GitHub OAuth flow with protocol handler + tokenData, err := oauth.LoginWithGitHub(ctx) + if err != nil { + return nil, fmt.Errorf("github login failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + + now := time.Now() + fileName := fmt.Sprintf("kiro-github-%s.json", idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: "kiro-github", + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "github-oauth", + "email": tokenData.Email, + }, + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), + } + + if tokenData.Email != "" { + fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! (Account: %s)\n", tokenData.Email) + } else { + fmt.Println("\n✓ Kiro GitHub authentication completed successfully!") + } + + return record, nil +} + +// ImportFromKiroIDE imports token from Kiro IDE's token file. +func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) { + tokenData, err := kiroauth.LoadKiroIDEToken() + if err != nil { + return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Extract email from JWT if not already set (for imported tokens) + if tokenData.Email == "" { + tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken) + } + + // Extract identifier for file naming + idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn) + // Sanitize provider to prevent path traversal (defense-in-depth) + provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider))) + if provider == "" { + provider = "imported" // Fallback for legacy tokens without provider + } + + now := time.Now() + fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart) + + record := &coreauth.Auth{ + ID: fileName, + Provider: "kiro", + FileName: fileName, + Label: fmt.Sprintf("kiro-%s", provider), + Status: coreauth.StatusActive, + CreatedAt: now, + UpdatedAt: now, + Metadata: map[string]any{ + "type": "kiro", + "access_token": tokenData.AccessToken, + "refresh_token": tokenData.RefreshToken, + "profile_arn": tokenData.ProfileArn, + "expires_at": tokenData.ExpiresAt, + "auth_method": tokenData.AuthMethod, + "provider": tokenData.Provider, + "email": tokenData.Email, + }, + Attributes: map[string]string{ + "profile_arn": tokenData.ProfileArn, + "source": "kiro-ide-import", + "email": tokenData.Email, + }, + // NextRefreshAfter is aligned with RefreshLead (5min) + NextRefreshAfter: expiresAt.Add(-5 * time.Minute), + } + + // Display the email if extracted + if tokenData.Email != "" { + fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email) + } else { + fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider) + } + + return record, nil +} + +// Refresh refreshes an expired Kiro token using AWS SSO OIDC. +func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) { + if auth == nil || auth.Metadata == nil { + return nil, fmt.Errorf("invalid auth record") + } + + refreshToken, ok := auth.Metadata["refresh_token"].(string) + if !ok || refreshToken == "" { + return nil, fmt.Errorf("refresh token not found") + } + + clientID, _ := auth.Metadata["client_id"].(string) + clientSecret, _ := auth.Metadata["client_secret"].(string) + authMethod, _ := auth.Metadata["auth_method"].(string) + + var tokenData *kiroauth.KiroTokenData + var err error + + // Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint + if clientID != "" && clientSecret != "" && authMethod == "builder-id" { + ssoClient := kiroauth.NewSSOOIDCClient(cfg) + tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken) + } else { + // Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub) + oauth := kiroauth.NewKiroOAuth(cfg) + tokenData, err = oauth.RefreshToken(ctx, refreshToken) + } + + if err != nil { + return nil, fmt.Errorf("token refresh failed: %w", err) + } + + // Parse expires_at + expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt) + if err != nil { + expiresAt = time.Now().Add(1 * time.Hour) + } + + // Clone auth to avoid mutating the input parameter + updated := auth.Clone() + now := time.Now() + updated.UpdatedAt = now + updated.LastRefreshedAt = now + updated.Metadata["access_token"] = tokenData.AccessToken + updated.Metadata["refresh_token"] = tokenData.RefreshToken + updated.Metadata["expires_at"] = tokenData.ExpiresAt + updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization + // NextRefreshAfter is aligned with RefreshLead (5min) + updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute) + + return updated, nil +} diff --git a/sdk/auth/manager.go b/sdk/auth/manager.go index c6469a7d..d630f128 100644 --- a/sdk/auth/manager.go +++ b/sdk/auth/manager.go @@ -74,3 +74,16 @@ func (m *Manager) Login(ctx context.Context, provider string, cfg *config.Config } return record, savedPath, nil } + +// SaveAuth persists an auth record directly without going through the login flow. +func (m *Manager) SaveAuth(record *coreauth.Auth, cfg *config.Config) (string, error) { + if m.store == nil { + return "", fmt.Errorf("no store configured") + } + if cfg != nil { + if dirSetter, ok := m.store.(interface{ SetBaseDir(string) }); ok { + dirSetter.SetBaseDir(cfg.AuthDir) + } + } + return m.store.Save(context.Background(), record) +} diff --git a/sdk/auth/refresh_registry.go b/sdk/auth/refresh_registry.go index e82ac684..c51712a2 100644 --- a/sdk/auth/refresh_registry.go +++ b/sdk/auth/refresh_registry.go @@ -14,6 +14,8 @@ func init() { registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() }) + registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() }) + registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() }) } func registerRefreshLead(provider string, factory func() Authenticator) { diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index 9f247bb9..a4a1ef36 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -40,7 +40,7 @@ type RefreshEvaluator interface { const ( refreshCheckInterval = 5 * time.Second refreshPendingBackoff = time.Minute - refreshFailureBackoff = 5 * time.Minute + refreshFailureBackoff = 1 * time.Minute quotaBackoffBase = time.Second quotaBackoffMax = 30 * time.Minute ) @@ -1498,7 +1498,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) { updated.Runtime = auth.Runtime } updated.LastRefreshedAt = now - updated.NextRefreshAfter = time.Time{} + // Preserve NextRefreshAfter set by the Authenticator + // If the Authenticator set a reasonable refresh time, it should not be overwritten + // If the Authenticator did not set it (zero value), shouldRefresh will use default logic updated.LastError = nil updated.UpdatedAt = now _, _ = m.Update(ctx, updated) diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 1ef829d1..251457a6 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -379,6 +379,10 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg)) case "iflow": s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg)) + case "kiro": + s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg)) + case "github-copilot": + s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg)) default: providerKey := strings.ToLower(strings.TrimSpace(a.Provider)) if providerKey == "" { @@ -720,6 +724,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { models = applyExcludedModels(models, excluded) case "iflow": models = registry.GetIFlowModels() + case "github-copilot": + models = registry.GetGitHubCopilotModels() + models = applyExcludedModels(models, excluded) + case "kiro": + models = registry.GetKiroModels() models = applyExcludedModels(models, excluded) default: // Handle OpenAI-compatibility providers by name using config