mirror of
https://github.com/router-for-me/CLIProxyAPIPlus.git
synced 2026-03-21 16:40:22 +00:00
Merge branch 'main' into plus
This commit is contained in:
2
.github/workflows/docker-image.yml
vendored
2
.github/workflows/docker-image.yml
vendored
@@ -7,7 +7,7 @@ on:
|
|||||||
|
|
||||||
env:
|
env:
|
||||||
APP_NAME: CLIProxyAPI
|
APP_NAME: CLIProxyAPI
|
||||||
DOCKERHUB_REPO: eceasy/cli-proxy-api
|
DOCKERHUB_REPO: eceasy/cli-proxy-api-plus
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
docker:
|
docker:
|
||||||
|
|||||||
3
.github/workflows/release.yaml
vendored
3
.github/workflows/release.yaml
vendored
@@ -23,7 +23,8 @@ jobs:
|
|||||||
cache: true
|
cache: true
|
||||||
- name: Generate Build Metadata
|
- name: Generate Build Metadata
|
||||||
run: |
|
run: |
|
||||||
echo VERSION=`git describe --tags --always --dirty` >> $GITHUB_ENV
|
VERSION=$(git describe --tags --always --dirty)
|
||||||
|
echo "VERSION=${VERSION}" >> $GITHUB_ENV
|
||||||
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
|
||||||
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
|
||||||
- uses: goreleaser/goreleaser-action@v4
|
- uses: goreleaser/goreleaser-action@v4
|
||||||
|
|||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,5 +1,6 @@
|
|||||||
# Binaries
|
# Binaries
|
||||||
cli-proxy-api
|
cli-proxy-api
|
||||||
|
cliproxy
|
||||||
*.exe
|
*.exe
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
@@ -32,6 +33,7 @@ GEMINI.md
|
|||||||
.claude/*
|
.claude/*
|
||||||
.serena/*
|
.serena/*
|
||||||
.bmad/*
|
.bmad/*
|
||||||
|
.mcp/cache/
|
||||||
|
|
||||||
# macOS
|
# macOS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
builds:
|
builds:
|
||||||
- id: "cli-proxy-api"
|
- id: "cli-proxy-api-plus"
|
||||||
env:
|
env:
|
||||||
- CGO_ENABLED=0
|
- CGO_ENABLED=0
|
||||||
goos:
|
goos:
|
||||||
@@ -10,11 +10,11 @@ builds:
|
|||||||
- amd64
|
- amd64
|
||||||
- arm64
|
- arm64
|
||||||
main: ./cmd/server/
|
main: ./cmd/server/
|
||||||
binary: cli-proxy-api
|
binary: cli-proxy-api-plus
|
||||||
ldflags:
|
ldflags:
|
||||||
- -s -w -X 'main.Version={{.Version}}' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}'
|
- -s -w -X 'main.Version={{.Version}}-plus' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}'
|
||||||
archives:
|
archives:
|
||||||
- id: "cli-proxy-api"
|
- id: "cli-proxy-api-plus"
|
||||||
format: tar.gz
|
format: tar.gz
|
||||||
format_overrides:
|
format_overrides:
|
||||||
- goos: windows
|
- goos: windows
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ ARG VERSION=dev
|
|||||||
ARG COMMIT=none
|
ARG COMMIT=none
|
||||||
ARG BUILD_DATE=unknown
|
ARG BUILD_DATE=unknown
|
||||||
|
|
||||||
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPI ./cmd/server/
|
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}-plus' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPIPlus ./cmd/server/
|
||||||
|
|
||||||
FROM alpine:3.22.0
|
FROM alpine:3.22.0
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ RUN apk add --no-cache tzdata
|
|||||||
|
|
||||||
RUN mkdir /CLIProxyAPI
|
RUN mkdir /CLIProxyAPI
|
||||||
|
|
||||||
COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI
|
COPY --from=builder ./app/CLIProxyAPIPlus /CLIProxyAPI/CLIProxyAPIPlus
|
||||||
|
|
||||||
COPY config.example.yaml /CLIProxyAPI/config.example.yaml
|
COPY config.example.yaml /CLIProxyAPI/config.example.yaml
|
||||||
|
|
||||||
@@ -32,4 +32,4 @@ ENV TZ=Asia/Shanghai
|
|||||||
|
|
||||||
RUN cp /usr/share/zoneinfo/${TZ} /etc/localtime && echo "${TZ}" > /etc/timezone
|
RUN cp /usr/share/zoneinfo/${TZ} /etc/localtime && echo "${TZ}" > /etc/timezone
|
||||||
|
|
||||||
CMD ["./CLIProxyAPI"]
|
CMD ["./CLIProxyAPIPlus"]
|
||||||
103
README.md
103
README.md
@@ -1,106 +1,23 @@
|
|||||||
# CLI Proxy API
|
# CLIProxyAPI Plus
|
||||||
|
|
||||||
English | [中文](README_CN.md)
|
English | [Chinese](README_CN.md)
|
||||||
|
|
||||||
A proxy server that provides OpenAI/Gemini/Claude/Codex compatible API interfaces for CLI.
|
This is the Plus version of [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI), adding support for third-party providers on top of the mainline project.
|
||||||
|
|
||||||
It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth.
|
All third-party provider support is maintained by community contributors; CLIProxyAPI does not provide technical support. Please contact the corresponding community maintainer if you need assistance.
|
||||||
|
|
||||||
So you can use local or multi-account CLI access with OpenAI(include Responses)/Gemini/Claude-compatible clients and SDKs.
|
The Plus release stays in lockstep with the mainline features.
|
||||||
|
|
||||||
## Sponsor
|
## Differences from the Mainline
|
||||||
|
|
||||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
- Added GitHub Copilot support (OAuth login), provided by [em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)
|
||||||
|
- Added Kiro (AWS CodeWhisperer) support (OAuth login), provided by [fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration), [Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)
|
||||||
This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN.
|
|
||||||
|
|
||||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.6 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
|
||||||
|
|
||||||
Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
|
|
||||||
- OpenAI Codex support (GPT models) via OAuth login
|
|
||||||
- Claude Code support via OAuth login
|
|
||||||
- Qwen Code support via OAuth login
|
|
||||||
- iFlow support via OAuth login
|
|
||||||
- Amp CLI and IDE extensions support with provider routing
|
|
||||||
- Streaming and non-streaming responses
|
|
||||||
- Function calling/tools support
|
|
||||||
- Multimodal input support (text and images)
|
|
||||||
- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude, Qwen and iFlow)
|
|
||||||
- Simple CLI authentication flows (Gemini, OpenAI, Claude, Qwen and iFlow)
|
|
||||||
- Generative Language API Key support
|
|
||||||
- AI Studio Build multi-account load balancing
|
|
||||||
- Gemini CLI multi-account load balancing
|
|
||||||
- Claude Code multi-account load balancing
|
|
||||||
- Qwen Code multi-account load balancing
|
|
||||||
- iFlow multi-account load balancing
|
|
||||||
- OpenAI Codex multi-account load balancing
|
|
||||||
- OpenAI-compatible upstream providers via config (e.g., OpenRouter)
|
|
||||||
- Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`)
|
|
||||||
|
|
||||||
## Getting Started
|
|
||||||
|
|
||||||
CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/)
|
|
||||||
|
|
||||||
## Management API
|
|
||||||
|
|
||||||
see [MANAGEMENT_API.md](https://help.router-for.me/management/api)
|
|
||||||
|
|
||||||
## Amp CLI Support
|
|
||||||
|
|
||||||
CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and Amp IDE extensions, enabling you to use your Google/ChatGPT/Claude OAuth subscriptions with Amp's coding tools:
|
|
||||||
|
|
||||||
- Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`)
|
|
||||||
- Management proxy for OAuth authentication and account features
|
|
||||||
- Smart model fallback with automatic routing
|
|
||||||
- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
|
|
||||||
- Security-first design with localhost-only management endpoints
|
|
||||||
|
|
||||||
**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/agent-client/amp-cli.html)**
|
|
||||||
|
|
||||||
## SDK Docs
|
|
||||||
|
|
||||||
- Usage: [docs/sdk-usage.md](docs/sdk-usage.md)
|
|
||||||
- Advanced (executors & translators): [docs/sdk-advanced.md](docs/sdk-advanced.md)
|
|
||||||
- Access: [docs/sdk-access.md](docs/sdk-access.md)
|
|
||||||
- Watcher: [docs/sdk-watcher.md](docs/sdk-watcher.md)
|
|
||||||
- Custom Provider Example: `examples/custom-provider`
|
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
Contributions are welcome! Please feel free to submit a Pull Request.
|
This project only accepts pull requests that relate to third-party provider support. Any pull requests unrelated to third-party provider support will be rejected.
|
||||||
|
|
||||||
1. Fork the repository
|
If you need to submit any non-third-party provider changes, please open them against the mainline repository.
|
||||||
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
|
|
||||||
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
|
|
||||||
4. Push to the branch (`git push origin feature/amazing-feature`)
|
|
||||||
5. Open a Pull Request
|
|
||||||
|
|
||||||
## Who is with us?
|
|
||||||
|
|
||||||
Those projects are based on CLIProxyAPI:
|
|
||||||
|
|
||||||
### [vibeproxy](https://github.com/automazeio/vibeproxy)
|
|
||||||
|
|
||||||
Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with AI coding tools - no API keys needed
|
|
||||||
|
|
||||||
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
|
|
||||||
|
|
||||||
Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed
|
|
||||||
|
|
||||||
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
|
|
||||||
|
|
||||||
CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed
|
|
||||||
|
|
||||||
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
|
|
||||||
|
|
||||||
Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed.
|
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
|
|||||||
109
README_CN.md
109
README_CN.md
@@ -1,113 +1,24 @@
|
|||||||
# CLI 代理 API
|
# CLIProxyAPI Plus
|
||||||
|
|
||||||
[English](README.md) | 中文
|
[English](README.md) | 中文
|
||||||
|
|
||||||
一个为 CLI 提供 OpenAI/Gemini/Claude/Codex 兼容 API 接口的代理服务器。
|
这是 [CLIProxyAPI](https://github.com/router-for-me/CLIProxyAPI) 的 Plus 版本,在原有基础上增加了第三方供应商的支持。
|
||||||
|
|
||||||
现已支持通过 OAuth 登录接入 OpenAI Codex(GPT 系列)和 Claude Code。
|
所有的第三方供应商支持都由第三方社区维护者提供,CLIProxyAPI 不提供技术支持。如需取得支持,请与对应的社区维护者联系。
|
||||||
|
|
||||||
您可以使用本地或多账户的CLI方式,通过任何与 OpenAI(包括Responses)/Gemini/Claude 兼容的客户端和SDK进行访问。
|
该 Plus 版本的主线功能与主线功能强制同步。
|
||||||
|
|
||||||
## 赞助商
|
## 与主线版本版本差异
|
||||||
|
|
||||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
- 新增 GitHub Copilot 支持(OAuth 登录),由[em4go](https://github.com/em4go/CLIProxyAPI/tree/feature/github-copilot-auth)提供
|
||||||
|
- 新增 Kiro (AWS CodeWhisperer) 支持 (OAuth 登录), 由[fuko2935](https://github.com/fuko2935/CLIProxyAPI/tree/feature/kiro-integration)、[Ravens2121](https://github.com/Ravens2121/CLIProxyAPIPlus/)提供
|
||||||
本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。
|
|
||||||
|
|
||||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.6,为开发者提供顶尖的编码体验。
|
|
||||||
|
|
||||||
智谱AI为本软件提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
|
|
||||||
|
|
||||||
## 功能特性
|
|
||||||
|
|
||||||
- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点
|
|
||||||
- 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录)
|
|
||||||
- 新增 Claude Code 支持(OAuth 登录)
|
|
||||||
- 新增 Qwen Code 支持(OAuth 登录)
|
|
||||||
- 新增 iFlow 支持(OAuth 登录)
|
|
||||||
- 支持流式与非流式响应
|
|
||||||
- 函数调用/工具支持
|
|
||||||
- 多模态输入(文本、图片)
|
|
||||||
- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude、Qwen 与 iFlow)
|
|
||||||
- 简单的 CLI 身份验证流程(Gemini、OpenAI、Claude、Qwen 与 iFlow)
|
|
||||||
- 支持 Gemini AIStudio API 密钥
|
|
||||||
- 支持 AI Studio Build 多账户轮询
|
|
||||||
- 支持 Gemini CLI 多账户轮询
|
|
||||||
- 支持 Claude Code 多账户轮询
|
|
||||||
- 支持 Qwen Code 多账户轮询
|
|
||||||
- 支持 iFlow 多账户轮询
|
|
||||||
- 支持 OpenAI Codex 多账户轮询
|
|
||||||
- 通过配置接入上游 OpenAI 兼容提供商(例如 OpenRouter)
|
|
||||||
- 可复用的 Go SDK(见 `docs/sdk-usage_CN.md`)
|
|
||||||
|
|
||||||
## 新手入门
|
|
||||||
|
|
||||||
CLIProxyAPI 用户手册: [https://help.router-for.me/](https://help.router-for.me/cn/)
|
|
||||||
|
|
||||||
## 管理 API 文档
|
|
||||||
|
|
||||||
请参见 [MANAGEMENT_API_CN.md](https://help.router-for.me/cn/management/api)
|
|
||||||
|
|
||||||
## Amp CLI 支持
|
|
||||||
|
|
||||||
CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支持,可让你使用自己的 Google/ChatGPT/Claude OAuth 订阅来配合 Amp 编码工具:
|
|
||||||
|
|
||||||
- 提供商路由别名,兼容 Amp 的 API 路径模式(`/api/provider/{provider}/v1...`)
|
|
||||||
- 管理代理,处理 OAuth 认证和账号功能
|
|
||||||
- 智能模型回退与自动路由
|
|
||||||
- 以安全为先的设计,管理端点仅限 localhost
|
|
||||||
|
|
||||||
**→ [Amp CLI 完整集成指南](https://help.router-for.me/cn/agent-client/amp-cli.html)**
|
|
||||||
|
|
||||||
## SDK 文档
|
|
||||||
|
|
||||||
- 使用文档:[docs/sdk-usage_CN.md](docs/sdk-usage_CN.md)
|
|
||||||
- 高级(执行器与翻译器):[docs/sdk-advanced_CN.md](docs/sdk-advanced_CN.md)
|
|
||||||
- 认证: [docs/sdk-access_CN.md](docs/sdk-access_CN.md)
|
|
||||||
- 凭据加载/更新: [docs/sdk-watcher_CN.md](docs/sdk-watcher_CN.md)
|
|
||||||
- 自定义 Provider 示例:`examples/custom-provider`
|
|
||||||
|
|
||||||
## 贡献
|
## 贡献
|
||||||
|
|
||||||
欢迎贡献!请随时提交 Pull Request。
|
该项目仅接受第三方供应商支持的 Pull Request。任何非第三方供应商支持的 Pull Request 都将被拒绝。
|
||||||
|
|
||||||
1. Fork 仓库
|
如果需要提交任何非第三方供应商支持的 Pull Request,请提交到主线版本。
|
||||||
2. 创建您的功能分支(`git checkout -b feature/amazing-feature`)
|
|
||||||
3. 提交您的更改(`git commit -m 'Add some amazing feature'`)
|
|
||||||
4. 推送到分支(`git push origin feature/amazing-feature`)
|
|
||||||
5. 打开 Pull Request
|
|
||||||
|
|
||||||
## 谁与我们在一起?
|
|
||||||
|
|
||||||
这些项目基于 CLIProxyAPI:
|
|
||||||
|
|
||||||
### [vibeproxy](https://github.com/automazeio/vibeproxy)
|
|
||||||
|
|
||||||
一个原生 macOS 菜单栏应用,让您可以使用 Claude Code & ChatGPT 订阅服务和 AI 编程工具,无需 API 密钥。
|
|
||||||
|
|
||||||
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
|
|
||||||
|
|
||||||
一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。
|
|
||||||
|
|
||||||
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
|
|
||||||
|
|
||||||
CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户和替代模型(Gemini, Codex, Antigravity),无需 API 密钥。
|
|
||||||
|
|
||||||
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
|
|
||||||
|
|
||||||
基于 macOS 平台的原生 CLIProxyAPI GUI:配置供应商、模型映射以及OAuth端点,无需 API 密钥。
|
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
|
||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||||
|
|
||||||
## 写给所有中国网友的
|
|
||||||
|
|
||||||
QQ 群:188637136
|
|
||||||
|
|
||||||
或
|
|
||||||
|
|
||||||
Telegram 群:https://t.me/CLIProxyAPI
|
|
||||||
@@ -47,6 +47,19 @@ func init() {
|
|||||||
buildinfo.BuildDate = BuildDate
|
buildinfo.BuildDate = BuildDate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setKiroIncognitoMode sets the incognito browser mode for Kiro authentication.
|
||||||
|
// Kiro defaults to incognito mode for multi-account support.
|
||||||
|
// Users can explicitly override with --incognito or --no-incognito flags.
|
||||||
|
func setKiroIncognitoMode(cfg *config.Config, useIncognito, noIncognito bool) {
|
||||||
|
if useIncognito {
|
||||||
|
cfg.IncognitoBrowser = true
|
||||||
|
} else if noIncognito {
|
||||||
|
cfg.IncognitoBrowser = false
|
||||||
|
} else {
|
||||||
|
cfg.IncognitoBrowser = true // Kiro default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// main is the entry point of the application.
|
// main is the entry point of the application.
|
||||||
// It parses command-line flags, loads configuration, and starts the appropriate
|
// It parses command-line flags, loads configuration, and starts the appropriate
|
||||||
// service based on the provided flags (login, codex-login, or server mode).
|
// service based on the provided flags (login, codex-login, or server mode).
|
||||||
@@ -62,10 +75,17 @@ func main() {
|
|||||||
var iflowCookie bool
|
var iflowCookie bool
|
||||||
var noBrowser bool
|
var noBrowser bool
|
||||||
var antigravityLogin bool
|
var antigravityLogin bool
|
||||||
|
var kiroLogin bool
|
||||||
|
var kiroGoogleLogin bool
|
||||||
|
var kiroAWSLogin bool
|
||||||
|
var kiroImport bool
|
||||||
|
var githubCopilotLogin bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var vertexImport string
|
var vertexImport string
|
||||||
var configPath string
|
var configPath string
|
||||||
var password string
|
var password string
|
||||||
|
var noIncognito bool
|
||||||
|
var useIncognito bool
|
||||||
|
|
||||||
// Define command-line flags for different operation modes.
|
// Define command-line flags for different operation modes.
|
||||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||||
@@ -75,7 +95,14 @@ func main() {
|
|||||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||||
|
flag.BoolVar(&useIncognito, "incognito", false, "Open browser in incognito/private mode for OAuth (useful for multiple accounts)")
|
||||||
|
flag.BoolVar(&noIncognito, "no-incognito", false, "Force disable incognito mode (uses existing browser session)")
|
||||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||||
|
flag.BoolVar(&kiroLogin, "kiro-login", false, "Login to Kiro using Google OAuth")
|
||||||
|
flag.BoolVar(&kiroGoogleLogin, "kiro-google-login", false, "Login to Kiro using Google OAuth (same as --kiro-login)")
|
||||||
|
flag.BoolVar(&kiroAWSLogin, "kiro-aws-login", false, "Login to Kiro using AWS Builder ID (device code flow)")
|
||||||
|
flag.BoolVar(&kiroImport, "kiro-import", false, "Import Kiro token from Kiro IDE (~/.aws/sso/cache/kiro-auth-token.json)")
|
||||||
|
flag.BoolVar(&githubCopilotLogin, "github-copilot-login", false, "Login to GitHub Copilot using device flow")
|
||||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||||
@@ -453,6 +480,9 @@ func main() {
|
|||||||
} else if antigravityLogin {
|
} else if antigravityLogin {
|
||||||
// Handle Antigravity login
|
// Handle Antigravity login
|
||||||
cmd.DoAntigravityLogin(cfg, options)
|
cmd.DoAntigravityLogin(cfg, options)
|
||||||
|
} else if githubCopilotLogin {
|
||||||
|
// Handle GitHub Copilot login
|
||||||
|
cmd.DoGitHubCopilotLogin(cfg, options)
|
||||||
} else if codexLogin {
|
} else if codexLogin {
|
||||||
// Handle Codex login
|
// Handle Codex login
|
||||||
cmd.DoCodexLogin(cfg, options)
|
cmd.DoCodexLogin(cfg, options)
|
||||||
@@ -465,6 +495,26 @@ func main() {
|
|||||||
cmd.DoIFlowLogin(cfg, options)
|
cmd.DoIFlowLogin(cfg, options)
|
||||||
} else if iflowCookie {
|
} else if iflowCookie {
|
||||||
cmd.DoIFlowCookieAuth(cfg, options)
|
cmd.DoIFlowCookieAuth(cfg, options)
|
||||||
|
} else if kiroLogin {
|
||||||
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
|
// Users can explicitly override with --no-incognito
|
||||||
|
// Note: This config mutation is safe - auth commands exit after completion
|
||||||
|
// and don't share config with StartService (which is in the else branch)
|
||||||
|
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||||
|
cmd.DoKiroLogin(cfg, options)
|
||||||
|
} else if kiroGoogleLogin {
|
||||||
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
|
// Users can explicitly override with --no-incognito
|
||||||
|
// Note: This config mutation is safe - auth commands exit after completion
|
||||||
|
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||||
|
cmd.DoKiroGoogleLogin(cfg, options)
|
||||||
|
} else if kiroAWSLogin {
|
||||||
|
// For Kiro auth, default to incognito mode for multi-account support
|
||||||
|
// Users can explicitly override with --no-incognito
|
||||||
|
setKiroIncognitoMode(cfg, useIncognito, noIncognito)
|
||||||
|
cmd.DoKiroAWSLogin(cfg, options)
|
||||||
|
} else if kiroImport {
|
||||||
|
cmd.DoKiroImport(cfg, options)
|
||||||
} else {
|
} else {
|
||||||
// In cloud deploy mode without config file, just wait for shutdown signals
|
// In cloud deploy mode without config file, just wait for shutdown signals
|
||||||
if isCloudDeploy && !configFileExists {
|
if isCloudDeploy && !configFileExists {
|
||||||
|
|||||||
@@ -39,6 +39,11 @@ api-keys:
|
|||||||
# Enable debug logging
|
# Enable debug logging
|
||||||
debug: false
|
debug: false
|
||||||
|
|
||||||
|
# Open OAuth URLs in incognito/private browser mode.
|
||||||
|
# Useful when you want to login with a different account without logging out from your current session.
|
||||||
|
# Default: false (but Kiro auth defaults to true for multi-account support)
|
||||||
|
incognito-browser: true
|
||||||
|
|
||||||
# When true, write application logs to rotating files instead of stdout
|
# When true, write application logs to rotating files instead of stdout
|
||||||
logging-to-file: false
|
logging-to-file: false
|
||||||
|
|
||||||
@@ -106,6 +111,16 @@ ws-auth: false
|
|||||||
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
||||||
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
||||||
|
|
||||||
|
# Kiro (AWS CodeWhisperer) configuration
|
||||||
|
# Note: Kiro API currently only operates in us-east-1 region
|
||||||
|
#kiro:
|
||||||
|
# - token-file: "~/.aws/sso/cache/kiro-auth-token.json" # path to Kiro token file
|
||||||
|
# agent-task-type: "" # optional: "vibe" or empty (API default)
|
||||||
|
# - access-token: "aoaAAAAA..." # or provide tokens directly
|
||||||
|
# refresh-token: "aorAAAAA..."
|
||||||
|
# profile-arn: "arn:aws:codewhisperer:us-east-1:..."
|
||||||
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: proxy override
|
||||||
|
|
||||||
# OpenAI compatibility providers
|
# OpenAI compatibility providers
|
||||||
# openai-compatibility:
|
# openai-compatibility:
|
||||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
services:
|
services:
|
||||||
cli-proxy-api:
|
cli-proxy-api:
|
||||||
image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api:latest}
|
image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api-plus:latest}
|
||||||
pull_policy: always
|
pull_policy: always
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
@@ -9,7 +9,7 @@ services:
|
|||||||
VERSION: ${VERSION:-dev}
|
VERSION: ${VERSION:-dev}
|
||||||
COMMIT: ${COMMIT:-none}
|
COMMIT: ${COMMIT:-none}
|
||||||
BUILD_DATE: ${BUILD_DATE:-unknown}
|
BUILD_DATE: ${BUILD_DATE:-unknown}
|
||||||
container_name: cli-proxy-api
|
container_name: cli-proxy-api-plus
|
||||||
# env_file:
|
# env_file:
|
||||||
# - .env
|
# - .env
|
||||||
environment:
|
environment:
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -13,14 +13,15 @@ require (
|
|||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/klauspost/compress v1.17.4
|
github.com/klauspost/compress v1.17.4
|
||||||
github.com/minio/minio-go/v7 v7.0.66
|
github.com/minio/minio-go/v7 v7.0.66
|
||||||
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
github.com/tiktoken-go/tokenizer v0.7.0
|
github.com/tiktoken-go/tokenizer v0.7.0
|
||||||
golang.org/x/crypto v0.45.0
|
golang.org/x/crypto v0.45.0
|
||||||
golang.org/x/net v0.47.0
|
golang.org/x/net v0.47.0
|
||||||
golang.org/x/oauth2 v0.30.0
|
golang.org/x/oauth2 v0.30.0
|
||||||
|
golang.org/x/term v0.36.0
|
||||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|||||||
5
go.sum
5
go.sum
@@ -116,6 +116,8 @@ github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6
|
|||||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||||
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
|
||||||
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
|
||||||
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||||
|
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
@@ -126,8 +128,6 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
|
|||||||
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
||||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
|
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
|
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
@@ -169,6 +169,7 @@ golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKl
|
|||||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ package management
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -23,6 +26,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||||
|
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
@@ -37,9 +41,32 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
oauthStatus = make(map[string]string)
|
oauthStatus = make(map[string]string)
|
||||||
|
oauthStatusMutex sync.RWMutex
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// getOAuthStatus safely retrieves an OAuth status
|
||||||
|
func getOAuthStatus(key string) (string, bool) {
|
||||||
|
oauthStatusMutex.RLock()
|
||||||
|
defer oauthStatusMutex.RUnlock()
|
||||||
|
status, ok := oauthStatus[key]
|
||||||
|
return status, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// setOAuthStatus safely sets an OAuth status
|
||||||
|
func setOAuthStatus(key string, status string) {
|
||||||
|
oauthStatusMutex.Lock()
|
||||||
|
defer oauthStatusMutex.Unlock()
|
||||||
|
oauthStatus[key] = status
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteOAuthStatus safely deletes an OAuth status
|
||||||
|
func deleteOAuthStatus(key string) {
|
||||||
|
oauthStatusMutex.Lock()
|
||||||
|
defer oauthStatusMutex.Unlock()
|
||||||
|
delete(oauthStatus, key)
|
||||||
|
}
|
||||||
|
|
||||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -812,7 +839,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
deadline := time.Now().Add(timeout)
|
deadline := time.Now().Add(timeout)
|
||||||
for {
|
for {
|
||||||
if time.Now().After(deadline) {
|
if time.Now().After(deadline) {
|
||||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
setOAuthStatus(state, "Timeout waiting for OAuth callback")
|
||||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||||
}
|
}
|
||||||
data, errRead := os.ReadFile(path)
|
data, errRead := os.ReadFile(path)
|
||||||
@@ -837,13 +864,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
if errStr := resultMap["error"]; errStr != "" {
|
if errStr := resultMap["error"]; errStr != "" {
|
||||||
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||||
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
||||||
oauthStatus[state] = "Bad request"
|
setOAuthStatus(state, "Bad request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if resultMap["state"] != state {
|
if resultMap["state"] != state {
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
||||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||||
oauthStatus[state] = "State code error"
|
setOAuthStatus(state, "State code error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -876,7 +903,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
setOAuthStatus(state, "Failed to exchange authorization code for tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -887,7 +914,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
respBody, _ := io.ReadAll(resp.Body)
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||||
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
|
setOAuthStatus(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
var tResp struct {
|
var tResp struct {
|
||||||
@@ -900,7 +927,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
||||||
log.Errorf("failed to parse token response: %v", errU)
|
log.Errorf("failed to parse token response: %v", errU)
|
||||||
oauthStatus[state] = "Failed to parse token response"
|
setOAuthStatus(state, "Failed to parse token response")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
bundle := &claude.ClaudeAuthBundle{
|
bundle := &claude.ClaudeAuthBundle{
|
||||||
@@ -925,7 +952,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
oauthStatus[state] = "Failed to save authentication tokens"
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -934,10 +961,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
|||||||
fmt.Println("API key obtained and saved")
|
fmt.Println("API key obtained and saved")
|
||||||
}
|
}
|
||||||
fmt.Println("You can now use Claude services through this CLI")
|
fmt.Println("You can now use Claude services through this CLI")
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -996,7 +1023,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
for {
|
for {
|
||||||
if time.Now().After(deadline) {
|
if time.Now().After(deadline) {
|
||||||
log.Error("oauth flow timed out")
|
log.Error("oauth flow timed out")
|
||||||
oauthStatus[state] = "OAuth flow timed out"
|
setOAuthStatus(state, "OAuth flow timed out")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||||
@@ -1005,13 +1032,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
_ = os.Remove(waitFile)
|
_ = os.Remove(waitFile)
|
||||||
if errStr := m["error"]; errStr != "" {
|
if errStr := m["error"]; errStr != "" {
|
||||||
log.Errorf("Authentication failed: %s", errStr)
|
log.Errorf("Authentication failed: %s", errStr)
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authCode = m["code"]
|
authCode = m["code"]
|
||||||
if authCode == "" {
|
if authCode == "" {
|
||||||
log.Errorf("Authentication failed: code not found")
|
log.Errorf("Authentication failed: code not found")
|
||||||
oauthStatus[state] = "Authentication failed: code not found"
|
setOAuthStatus(state, "Authentication failed: code not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
@@ -1023,7 +1050,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
token, err := conf.Exchange(ctx, authCode)
|
token, err := conf.Exchange(ctx, authCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to exchange token: %v", err)
|
log.Errorf("Failed to exchange token: %v", err)
|
||||||
oauthStatus[state] = "Failed to exchange token"
|
setOAuthStatus(state, "Failed to exchange token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1034,7 +1061,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||||
if errNewRequest != nil {
|
if errNewRequest != nil {
|
||||||
log.Errorf("Could not get user info: %v", errNewRequest)
|
log.Errorf("Could not get user info: %v", errNewRequest)
|
||||||
oauthStatus[state] = "Could not get user info"
|
setOAuthStatus(state, "Could not get user info")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
@@ -1043,7 +1070,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
resp, errDo := authHTTPClient.Do(req)
|
resp, errDo := authHTTPClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
log.Errorf("Failed to execute request: %v", errDo)
|
log.Errorf("Failed to execute request: %v", errDo)
|
||||||
oauthStatus[state] = "Failed to execute request"
|
setOAuthStatus(state, "Failed to execute request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1055,7 +1082,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
|
setOAuthStatus(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1064,7 +1091,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
fmt.Printf("Authenticated user email: %s\n", email)
|
fmt.Printf("Authenticated user email: %s\n", email)
|
||||||
} else {
|
} else {
|
||||||
fmt.Println("Failed to get user email from token")
|
fmt.Println("Failed to get user email from token")
|
||||||
oauthStatus[state] = "Failed to get user email from token"
|
setOAuthStatus(state, "Failed to get user email from token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
||||||
@@ -1072,7 +1099,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
jsonData, _ := json.Marshal(token)
|
jsonData, _ := json.Marshal(token)
|
||||||
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
||||||
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
||||||
oauthStatus[state] = "Failed to unmarshal token"
|
setOAuthStatus(state, "Failed to unmarshal token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1098,7 +1125,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
||||||
if errGetClient != nil {
|
if errGetClient != nil {
|
||||||
log.Errorf("failed to get authenticated client: %v", errGetClient)
|
log.Errorf("failed to get authenticated client: %v", errGetClient)
|
||||||
oauthStatus[state] = "Failed to get authenticated client"
|
setOAuthStatus(state, "Failed to get authenticated client")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Println("Authentication successful.")
|
fmt.Println("Authentication successful.")
|
||||||
@@ -1108,12 +1135,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||||
if errAll != nil {
|
if errAll != nil {
|
||||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
setOAuthStatus(state, "Failed to complete Gemini CLI onboarding")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
setOAuthStatus(state, "Failed to verify Cloud AI API status")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ts.ProjectID = strings.Join(projects, ",")
|
ts.ProjectID = strings.Join(projects, ",")
|
||||||
@@ -1121,26 +1148,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
setOAuthStatus(state, "Failed to complete Gemini CLI onboarding")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||||
log.Error("Onboarding did not return a project ID")
|
log.Error("Onboarding did not return a project ID")
|
||||||
oauthStatus[state] = "Failed to resolve project ID"
|
setOAuthStatus(state, "Failed to resolve project ID")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||||
if errCheck != nil {
|
if errCheck != nil {
|
||||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
setOAuthStatus(state, "Failed to verify Cloud AI API status")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ts.Checked = isChecked
|
ts.Checked = isChecked
|
||||||
if !isChecked {
|
if !isChecked {
|
||||||
log.Error("Cloud AI API is not enabled for the selected project")
|
log.Error("Cloud AI API is not enabled for the selected project")
|
||||||
oauthStatus[state] = "Cloud AI API not enabled"
|
setOAuthStatus(state, "Cloud AI API not enabled")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1163,15 +1190,15 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
log.Errorf("Failed to save token to file: %v", errSave)
|
log.Errorf("Failed to save token to file: %v", errSave)
|
||||||
oauthStatus[state] = "Failed to save token to file"
|
setOAuthStatus(state, "Failed to save token to file")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
|
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1235,7 +1262,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
if time.Now().After(deadline) {
|
if time.Now().After(deadline) {
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
||||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
setOAuthStatus(state, "Timeout waiting for OAuth callback")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||||
@@ -1245,12 +1272,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
if errStr := m["error"]; errStr != "" {
|
if errStr := m["error"]; errStr != "" {
|
||||||
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||||
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
||||||
oauthStatus[state] = "Bad Request"
|
setOAuthStatus(state, "Bad Request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if m["state"] != state {
|
if m["state"] != state {
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
||||||
oauthStatus[state] = "State code error"
|
setOAuthStatus(state, "State code error")
|
||||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1281,14 +1308,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
resp, errDo := httpClient.Do(req)
|
resp, errDo := httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
||||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
setOAuthStatus(state, "Failed to exchange authorization code for tokens")
|
||||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
respBody, _ := io.ReadAll(resp.Body)
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
|
setOAuthStatus(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
||||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1299,7 +1326,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
ExpiresIn int `json:"expires_in"`
|
ExpiresIn int `json:"expires_in"`
|
||||||
}
|
}
|
||||||
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
||||||
oauthStatus[state] = "Failed to parse token response"
|
setOAuthStatus(state, "Failed to parse token response")
|
||||||
log.Errorf("failed to parse token response: %v", errU)
|
log.Errorf("failed to parse token response: %v", errU)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1337,8 +1364,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
oauthStatus[state] = "Failed to save authentication tokens"
|
|
||||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
@@ -1346,10 +1373,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
|||||||
fmt.Println("API key obtained and saved")
|
fmt.Println("API key obtained and saved")
|
||||||
}
|
}
|
||||||
fmt.Println("You can now use Codex services through this CLI")
|
fmt.Println("You can now use Codex services through this CLI")
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1416,7 +1443,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
for {
|
for {
|
||||||
if time.Now().After(deadline) {
|
if time.Now().After(deadline) {
|
||||||
log.Error("oauth flow timed out")
|
log.Error("oauth flow timed out")
|
||||||
oauthStatus[state] = "OAuth flow timed out"
|
setOAuthStatus(state, "OAuth flow timed out")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
||||||
@@ -1425,18 +1452,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
_ = os.Remove(waitFile)
|
_ = os.Remove(waitFile)
|
||||||
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||||
log.Errorf("Authentication failed: %s", errStr)
|
log.Errorf("Authentication failed: %s", errStr)
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
||||||
log.Errorf("Authentication failed: state mismatch")
|
log.Errorf("Authentication failed: state mismatch")
|
||||||
oauthStatus[state] = "Authentication failed: state mismatch"
|
setOAuthStatus(state, "Authentication failed: state mismatch")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authCode = strings.TrimSpace(payload["code"])
|
authCode = strings.TrimSpace(payload["code"])
|
||||||
if authCode == "" {
|
if authCode == "" {
|
||||||
log.Error("Authentication failed: code not found")
|
log.Error("Authentication failed: code not found")
|
||||||
oauthStatus[state] = "Authentication failed: code not found"
|
setOAuthStatus(state, "Authentication failed: code not found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
@@ -1455,7 +1482,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
||||||
if errNewRequest != nil {
|
if errNewRequest != nil {
|
||||||
log.Errorf("Failed to build token request: %v", errNewRequest)
|
log.Errorf("Failed to build token request: %v", errNewRequest)
|
||||||
oauthStatus[state] = "Failed to build token request"
|
setOAuthStatus(state, "Failed to build token request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
@@ -1463,7 +1490,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
resp, errDo := httpClient.Do(req)
|
resp, errDo := httpClient.Do(req)
|
||||||
if errDo != nil {
|
if errDo != nil {
|
||||||
log.Errorf("Failed to execute token request: %v", errDo)
|
log.Errorf("Failed to execute token request: %v", errDo)
|
||||||
oauthStatus[state] = "Failed to exchange token"
|
setOAuthStatus(state, "Failed to exchange token")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1475,7 +1502,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)
|
setOAuthStatus(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1487,7 +1514,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
||||||
log.Errorf("Failed to parse token response: %v", errDecode)
|
log.Errorf("Failed to parse token response: %v", errDecode)
|
||||||
oauthStatus[state] = "Failed to parse token response"
|
setOAuthStatus(state, "Failed to parse token response")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1496,7 +1523,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||||
if errInfoReq != nil {
|
if errInfoReq != nil {
|
||||||
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
||||||
oauthStatus[state] = "Failed to build user info request"
|
setOAuthStatus(state, "Failed to build user info request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
||||||
@@ -1504,7 +1531,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
infoResp, errInfo := httpClient.Do(infoReq)
|
infoResp, errInfo := httpClient.Do(infoReq)
|
||||||
if errInfo != nil {
|
if errInfo != nil {
|
||||||
log.Errorf("Failed to execute user info request: %v", errInfo)
|
log.Errorf("Failed to execute user info request: %v", errInfo)
|
||||||
oauthStatus[state] = "Failed to execute user info request"
|
setOAuthStatus(state, "Failed to execute user info request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1523,7 +1550,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
||||||
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
||||||
oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)
|
setOAuthStatus(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1571,11 +1598,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
log.Errorf("Failed to save token to file: %v", errSave)
|
log.Errorf("Failed to save token to file: %v", errSave)
|
||||||
oauthStatus[state] = "Failed to save token to file"
|
setOAuthStatus(state, "Failed to save token to file")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
if projectID != "" {
|
if projectID != "" {
|
||||||
fmt.Printf("Using GCP project: %s\n", projectID)
|
fmt.Printf("Using GCP project: %s\n", projectID)
|
||||||
@@ -1583,7 +1610,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
fmt.Println("You can now use Antigravity services through this CLI")
|
fmt.Println("You can now use Antigravity services through this CLI")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1609,7 +1636,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
fmt.Println("Waiting for authentication...")
|
fmt.Println("Waiting for authentication...")
|
||||||
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
||||||
if errPollForToken != nil {
|
if errPollForToken != nil {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1628,16 +1655,16 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
|||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
oauthStatus[state] = "Failed to save authentication tokens"
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
fmt.Println("You can now use Qwen services through this CLI")
|
fmt.Println("You can now use Qwen services through this CLI")
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1676,7 +1703,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
var resultMap map[string]string
|
var resultMap map[string]string
|
||||||
for {
|
for {
|
||||||
if time.Now().After(deadline) {
|
if time.Now().After(deadline) {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Println("Authentication failed: timeout waiting for callback")
|
fmt.Println("Authentication failed: timeout waiting for callback")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1689,26 +1716,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Printf("Authentication failed: %s\n", errStr)
|
fmt.Printf("Authentication failed: %s\n", errStr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Println("Authentication failed: state mismatch")
|
fmt.Println("Authentication failed: state mismatch")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
code := strings.TrimSpace(resultMap["code"])
|
code := strings.TrimSpace(resultMap["code"])
|
||||||
if code == "" {
|
if code == "" {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Println("Authentication failed: code missing")
|
fmt.Println("Authentication failed: code missing")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
||||||
if errExchange != nil {
|
if errExchange != nil {
|
||||||
oauthStatus[state] = "Authentication failed"
|
setOAuthStatus(state, "Authentication failed")
|
||||||
fmt.Printf("Authentication failed: %v\n", errExchange)
|
fmt.Printf("Authentication failed: %v\n", errExchange)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1730,8 +1757,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
|
|
||||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
if errSave != nil {
|
if errSave != nil {
|
||||||
oauthStatus[state] = "Failed to save authentication tokens"
|
|
||||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1740,10 +1767,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
|||||||
fmt.Println("API key obtained and saved")
|
fmt.Println("API key obtained and saved")
|
||||||
}
|
}
|
||||||
fmt.Println("You can now use iFlow services through this CLI")
|
fmt.Println("You can now use iFlow services through this CLI")
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
oauthStatus[state] = ""
|
setOAuthStatus(state, "")
|
||||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2180,9 +2207,35 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
|||||||
|
|
||||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
if err, ok := oauthStatus[state]; ok {
|
if statusValue, ok := getOAuthStatus(state); ok {
|
||||||
if err != "" {
|
if statusValue != "" {
|
||||||
c.JSON(200, gin.H{"status": "error", "error": err})
|
// Check for device_code prefix (Kiro AWS Builder ID flow)
|
||||||
|
// Format: "device_code|verification_url|user_code"
|
||||||
|
// Using "|" as separator because URLs contain ":"
|
||||||
|
if strings.HasPrefix(statusValue, "device_code|") {
|
||||||
|
parts := strings.SplitN(statusValue, "|", 3)
|
||||||
|
if len(parts) == 3 {
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"status": "device_code",
|
||||||
|
"verification_url": parts[1],
|
||||||
|
"user_code": parts[2],
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check for auth_url prefix (Kiro social auth flow)
|
||||||
|
// Format: "auth_url|url"
|
||||||
|
// Using "|" as separator because URLs contain ":"
|
||||||
|
if strings.HasPrefix(statusValue, "auth_url|") {
|
||||||
|
authURL := strings.TrimPrefix(statusValue, "auth_url|")
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"status": "auth_url",
|
||||||
|
"url": authURL,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Otherwise treat as error
|
||||||
|
c.JSON(200, gin.H{"status": "error", "error": statusValue})
|
||||||
} else {
|
} else {
|
||||||
c.JSON(200, gin.H{"status": "wait"})
|
c.JSON(200, gin.H{"status": "wait"})
|
||||||
return
|
return
|
||||||
@@ -2190,5 +2243,297 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
c.JSON(200, gin.H{"status": "ok"})
|
c.JSON(200, gin.H{"status": "ok"})
|
||||||
}
|
}
|
||||||
delete(oauthStatus, state)
|
deleteOAuthStatus(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
const kiroCallbackPort = 9876
|
||||||
|
|
||||||
|
func (h *Handler) RequestKiroToken(c *gin.Context) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Get the login method from query parameter (default: aws for device code flow)
|
||||||
|
method := strings.ToLower(strings.TrimSpace(c.Query("method")))
|
||||||
|
if method == "" {
|
||||||
|
method = "aws"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Initializing Kiro authentication...")
|
||||||
|
|
||||||
|
state := fmt.Sprintf("kiro-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
|
switch method {
|
||||||
|
case "aws", "builder-id":
|
||||||
|
// AWS Builder ID uses device code flow (no callback needed)
|
||||||
|
go func() {
|
||||||
|
ssoClient := kiroauth.NewSSOOIDCClient(h.cfg)
|
||||||
|
|
||||||
|
// Step 1: Register client
|
||||||
|
fmt.Println("Registering client...")
|
||||||
|
regResp, err := ssoClient.RegisterClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to register client: %v", err)
|
||||||
|
setOAuthStatus(state, "Failed to register client")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Start device authorization
|
||||||
|
fmt.Println("Starting device authorization...")
|
||||||
|
authResp, err := ssoClient.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to start device auth: %v", err)
|
||||||
|
setOAuthStatus(state, "Failed to start device authorization")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the verification URL for the frontend to display
|
||||||
|
// Using "|" as separator because URLs contain ":"
|
||||||
|
setOAuthStatus(state, "device_code|"+authResp.VerificationURIComplete+"|"+authResp.UserCode)
|
||||||
|
|
||||||
|
// Step 3: Poll for token
|
||||||
|
fmt.Println("Waiting for authorization...")
|
||||||
|
interval := 5 * time.Second
|
||||||
|
if authResp.Interval > 0 {
|
||||||
|
interval = time.Duration(authResp.Interval) * time.Second
|
||||||
|
}
|
||||||
|
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
setOAuthStatus(state, "Authorization cancelled")
|
||||||
|
return
|
||||||
|
case <-time.After(interval):
|
||||||
|
tokenResp, err := ssoClient.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
||||||
|
if err != nil {
|
||||||
|
errStr := err.Error()
|
||||||
|
if strings.Contains(errStr, "authorization_pending") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(errStr, "slow_down") {
|
||||||
|
interval += 5 * time.Second
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Errorf("Token creation failed: %v", err)
|
||||||
|
setOAuthStatus(state, "Token creation failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success! Save the token
|
||||||
|
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||||
|
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||||
|
|
||||||
|
idPart := kiroauth.SanitizeEmailForFilename(email)
|
||||||
|
if idPart == "" {
|
||||||
|
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenResp.AccessToken,
|
||||||
|
"refresh_token": tokenResp.RefreshToken,
|
||||||
|
"expires_at": expiresAt.Format(time.RFC3339),
|
||||||
|
"auth_method": "builder-id",
|
||||||
|
"provider": "AWS",
|
||||||
|
"client_id": regResp.ClientID,
|
||||||
|
"client_secret": regResp.ClientSecret,
|
||||||
|
"email": email,
|
||||||
|
"last_refresh": now.Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
|
if errSave != nil {
|
||||||
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
|
if email != "" {
|
||||||
|
fmt.Printf("Authenticated as: %s\n", email)
|
||||||
|
}
|
||||||
|
deleteOAuthStatus(state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
setOAuthStatus(state, "Authorization timed out")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Return immediately with the state for polling
|
||||||
|
c.JSON(200, gin.H{"status": "ok", "state": state, "method": "device_code"})
|
||||||
|
|
||||||
|
case "google", "github":
|
||||||
|
// Social auth uses protocol handler - for WEB UI we use a callback forwarder
|
||||||
|
provider := "Google"
|
||||||
|
if method == "github" {
|
||||||
|
provider = "Github"
|
||||||
|
}
|
||||||
|
|
||||||
|
isWebUI := isWebUIRequest(c)
|
||||||
|
if isWebUI {
|
||||||
|
targetURL, errTarget := h.managementCallbackURL("/kiro/callback")
|
||||||
|
if errTarget != nil {
|
||||||
|
log.WithError(errTarget).Error("failed to compute kiro callback target")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, errStart := startCallbackForwarder(kiroCallbackPort, "kiro", targetURL); errStart != nil {
|
||||||
|
log.WithError(errStart).Error("failed to start kiro callback forwarder")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if isWebUI {
|
||||||
|
defer stopCallbackForwarder(kiroCallbackPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
socialClient := kiroauth.NewSocialAuthClient(h.cfg)
|
||||||
|
|
||||||
|
// Generate PKCE codes
|
||||||
|
codeVerifier, codeChallenge, err := generateKiroPKCE()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to generate PKCE: %v", err)
|
||||||
|
setOAuthStatus(state, "Failed to generate PKCE")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build login URL
|
||||||
|
authURL := fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
|
||||||
|
"https://prod.us-east-1.auth.desktop.kiro.dev",
|
||||||
|
provider,
|
||||||
|
url.QueryEscape(kiroauth.KiroRedirectURI),
|
||||||
|
codeChallenge,
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store auth URL for frontend
|
||||||
|
// Using "|" as separator because URLs contain ":"
|
||||||
|
setOAuthStatus(state, "auth_url|"+authURL)
|
||||||
|
|
||||||
|
// Wait for callback file
|
||||||
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-kiro-%s.oauth", state))
|
||||||
|
deadline := time.Now().Add(5 * time.Minute)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
log.Error("oauth flow timed out")
|
||||||
|
setOAuthStatus(state, "OAuth flow timed out")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||||
|
var m map[string]string
|
||||||
|
_ = json.Unmarshal(data, &m)
|
||||||
|
_ = os.Remove(waitFile)
|
||||||
|
if errStr := m["error"]; errStr != "" {
|
||||||
|
log.Errorf("Authentication failed: %s", errStr)
|
||||||
|
setOAuthStatus(state, "Authentication failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m["state"] != state {
|
||||||
|
log.Errorf("State mismatch")
|
||||||
|
setOAuthStatus(state, "State mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
code := m["code"]
|
||||||
|
if code == "" {
|
||||||
|
log.Error("No authorization code received")
|
||||||
|
setOAuthStatus(state, "No authorization code received")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange code for tokens
|
||||||
|
tokenReq := &kiroauth.CreateTokenRequest{
|
||||||
|
Code: code,
|
||||||
|
CodeVerifier: codeVerifier,
|
||||||
|
RedirectURI: kiroauth.KiroRedirectURI,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResp, errToken := socialClient.CreateToken(ctx, tokenReq)
|
||||||
|
if errToken != nil {
|
||||||
|
log.Errorf("Failed to exchange code for tokens: %v", errToken)
|
||||||
|
setOAuthStatus(state, "Failed to exchange code for tokens")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the token
|
||||||
|
expiresIn := tokenResp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
email := kiroauth.ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||||
|
|
||||||
|
idPart := kiroauth.SanitizeEmailForFilename(email)
|
||||||
|
if idPart == "" {
|
||||||
|
idPart = fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-%s-%s.json", strings.ToLower(provider), idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenResp.AccessToken,
|
||||||
|
"refresh_token": tokenResp.RefreshToken,
|
||||||
|
"profile_arn": tokenResp.ProfileArn,
|
||||||
|
"expires_at": expiresAt.Format(time.RFC3339),
|
||||||
|
"auth_method": "social",
|
||||||
|
"provider": provider,
|
||||||
|
"email": email,
|
||||||
|
"last_refresh": now.Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
|
if errSave != nil {
|
||||||
|
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||||
|
setOAuthStatus(state, "Failed to save authentication tokens")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
|
if email != "" {
|
||||||
|
fmt.Printf("Authenticated as: %s\n", email)
|
||||||
|
}
|
||||||
|
deleteOAuthStatus(state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
setOAuthStatus(state, "")
|
||||||
|
c.JSON(200, gin.H{"status": "ok", "state": state, "method": "social"})
|
||||||
|
|
||||||
|
default:
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid method, use 'aws', 'google', or 'github'"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateKiroPKCE generates PKCE code verifier and challenge for Kiro OAuth.
|
||||||
|
func generateKiroPKCE() (verifier, challenge string, err error) {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||||
|
}
|
||||||
|
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
|
||||||
|
h := sha256.Sum256([]byte(verifier))
|
||||||
|
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
|
||||||
|
return verifier, challenge, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPI/releases/latest"
|
latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPIPlus/releases/latest"
|
||||||
latestReleaseUserAgent = "CLIProxyAPI"
|
latestReleaseUserAgent = "CLIProxyAPIPlus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *Handler) GetConfig(c *gin.Context) {
|
func (h *Handler) GetConfig(c *gin.Context) {
|
||||||
|
|||||||
@@ -3,8 +3,11 @@ package amp
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -67,7 +70,15 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
|||||||
// Modify incoming responses to handle gzip without Content-Encoding
|
// Modify incoming responses to handle gzip without Content-Encoding
|
||||||
// This addresses the same issue as inline handler gzip handling, but at the proxy level
|
// This addresses the same issue as inline handler gzip handling, but at the proxy level
|
||||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
// Only process successful responses
|
// Log upstream error responses for diagnostics (502, 503, etc.)
|
||||||
|
// These are NOT proxy connection errors - the upstream responded with an error status
|
||||||
|
if resp.StatusCode >= 500 {
|
||||||
|
log.Errorf("amp upstream responded with error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||||
|
} else if resp.StatusCode >= 400 {
|
||||||
|
log.Warnf("amp upstream responded with client error [%d] for %s %s", resp.StatusCode, resp.Request.Method, resp.Request.URL.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only process successful responses for gzip decompression
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -151,9 +162,29 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error handler for proxy failures
|
// Error handler for proxy failures with detailed error classification for diagnostics
|
||||||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
|
// Classify the error type for better diagnostics
|
||||||
|
var errType string
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
errType = "timeout"
|
||||||
|
} else if errors.Is(err, context.Canceled) {
|
||||||
|
errType = "canceled"
|
||||||
|
} else if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
errType = "dial_timeout"
|
||||||
|
} else if _, ok := err.(net.Error); ok {
|
||||||
|
errType = "network_error"
|
||||||
|
} else {
|
||||||
|
errType = "connection_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't log as error for context canceled - it's usually client closing connection
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
log.Debugf("amp upstream proxy [%s]: client canceled request for %s %s", errType, req.Method, req.URL.Path)
|
||||||
|
} else {
|
||||||
|
log.Errorf("amp upstream proxy error [%s] for %s %s: %v", errType, req.Method, req.URL.Path, err)
|
||||||
|
}
|
||||||
|
|
||||||
rw.Header().Set("Content-Type", "application/json")
|
rw.Header().Set("Content-Type", "application/json")
|
||||||
rw.WriteHeader(http.StatusBadGateway)
|
rw.WriteHeader(http.StatusBadGateway)
|
||||||
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
|
_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
|
||||||
|
|||||||
@@ -29,15 +29,71 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
|
||||||
|
|
||||||
|
func looksLikeSSEChunk(data []byte) bool {
|
||||||
|
// Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered.
|
||||||
|
// Heuristics are intentionally simple and cheap.
|
||||||
|
return bytes.Contains(data, []byte("data:")) ||
|
||||||
|
bytes.Contains(data, []byte("event:")) ||
|
||||||
|
bytes.Contains(data, []byte("message_start")) ||
|
||||||
|
bytes.Contains(data, []byte("message_delta")) ||
|
||||||
|
bytes.Contains(data, []byte("content_block_start")) ||
|
||||||
|
bytes.Contains(data, []byte("content_block_delta")) ||
|
||||||
|
bytes.Contains(data, []byte("content_block_stop")) ||
|
||||||
|
bytes.Contains(data, []byte("\n\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rw *ResponseRewriter) enableStreaming(reason string) error {
|
||||||
|
if rw.isStreaming {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rw.isStreaming = true
|
||||||
|
|
||||||
|
// Flush any previously buffered data to avoid reordering or data loss.
|
||||||
|
if rw.body != nil && rw.body.Len() > 0 {
|
||||||
|
buf := rw.body.Bytes()
|
||||||
|
// Copy before Reset() to keep bytes stable.
|
||||||
|
toFlush := make([]byte, len(buf))
|
||||||
|
copy(toFlush, buf)
|
||||||
|
rw.body.Reset()
|
||||||
|
|
||||||
|
if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("amp response rewriter: switched to streaming (%s)", reason)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Write intercepts response writes and buffers them for model name replacement
|
// Write intercepts response writes and buffers them for model name replacement
|
||||||
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
||||||
// Detect streaming on first write
|
// Detect streaming on first write (header-based)
|
||||||
if rw.body.Len() == 0 && !rw.isStreaming {
|
if !rw.isStreaming && rw.body.Len() == 0 {
|
||||||
contentType := rw.Header().Get("Content-Type")
|
contentType := rw.Header().Get("Content-Type")
|
||||||
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
||||||
strings.Contains(contentType, "stream")
|
strings.Contains(contentType, "stream")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !rw.isStreaming {
|
||||||
|
// Content-based fallback: detect SSE-like chunks even if Content-Type is missing/wrong.
|
||||||
|
if looksLikeSSEChunk(data) {
|
||||||
|
if err := rw.enableStreaming("sse heuristic"); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
} else if rw.body.Len()+len(data) > maxBufferedResponseBytes {
|
||||||
|
// Safety cap: avoid unbounded buffering on large responses.
|
||||||
|
log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes)
|
||||||
|
if err := rw.enableStreaming("buffer limit"); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if rw.isStreaming {
|
if rw.isStreaming {
|
||||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|||||||
@@ -349,6 +349,12 @@ func (s *Server) setupRoutes() {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Event logging endpoint - handles Claude Code telemetry requests
|
||||||
|
// Returns 200 OK to prevent 404 errors in logs
|
||||||
|
s.engine.POST("/api/event_logging/batch", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
|
})
|
||||||
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
|
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
|
||||||
|
|
||||||
// OAuth callback endpoints (reuse main server port)
|
// OAuth callback endpoints (reuse main server port)
|
||||||
@@ -415,6 +421,18 @@ func (s *Server) setupRoutes() {
|
|||||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
s.engine.GET("/kiro/callback", func(c *gin.Context) {
|
||||||
|
code := c.Query("code")
|
||||||
|
state := c.Query("state")
|
||||||
|
errStr := c.Query("error")
|
||||||
|
if state != "" {
|
||||||
|
file := fmt.Sprintf("%s/.oauth-kiro-%s.oauth", s.cfg.AuthDir, state)
|
||||||
|
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
|
})
|
||||||
|
|
||||||
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -581,6 +599,7 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||||
|
mgmt.GET("/kiro-auth-url", s.mgmt.RequestKiroToken)
|
||||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -923,7 +942,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
for _, p := range cfg.OpenAICompatibility {
|
for _, p := range cfg.OpenAICompatibility {
|
||||||
providerNames = append(providerNames, p.Name)
|
providerNames = append(providerNames, p.Name)
|
||||||
}
|
}
|
||||||
s.handlers.OpenAICompatProviders = providerNames
|
s.handlers.SetOpenAICompatProviders(providerNames)
|
||||||
|
|
||||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||||
|
|
||||||
|
|||||||
@@ -242,6 +242,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
|||||||
platformURL = "https://console.anthropic.com/"
|
platformURL = "https://console.anthropic.com/"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate platformURL to prevent XSS - only allow http/https URLs
|
||||||
|
if !isValidURL(platformURL) {
|
||||||
|
platformURL = "https://console.anthropic.com/"
|
||||||
|
}
|
||||||
|
|
||||||
// Generate success page HTML with dynamic content
|
// Generate success page HTML with dynamic content
|
||||||
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
||||||
|
|
||||||
@@ -251,6 +256,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
|
||||||
|
func isValidURL(urlStr string) bool {
|
||||||
|
urlStr = strings.TrimSpace(urlStr)
|
||||||
|
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
|
||||||
|
}
|
||||||
|
|
||||||
// generateSuccessHTML creates the HTML content for the success page.
|
// generateSuccessHTML creates the HTML content for the success page.
|
||||||
// It customizes the page based on whether additional setup is required
|
// It customizes the page based on whether additional setup is required
|
||||||
// and includes a link to the platform.
|
// and includes a link to the platform.
|
||||||
|
|||||||
@@ -239,6 +239,11 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
|||||||
platformURL = "https://platform.openai.com"
|
platformURL = "https://platform.openai.com"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate platformURL to prevent XSS - only allow http/https URLs
|
||||||
|
if !isValidURL(platformURL) {
|
||||||
|
platformURL = "https://platform.openai.com"
|
||||||
|
}
|
||||||
|
|
||||||
// Generate success page HTML with dynamic content
|
// Generate success page HTML with dynamic content
|
||||||
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
||||||
|
|
||||||
@@ -248,6 +253,12 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isValidURL checks if the URL is a valid http/https URL to prevent XSS
|
||||||
|
func isValidURL(urlStr string) bool {
|
||||||
|
urlStr = strings.TrimSpace(urlStr)
|
||||||
|
return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://")
|
||||||
|
}
|
||||||
|
|
||||||
// generateSuccessHTML creates the HTML content for the success page.
|
// generateSuccessHTML creates the HTML content for the success page.
|
||||||
// It customizes the page based on whether additional setup is required
|
// It customizes the page based on whether additional setup is required
|
||||||
// and includes a link to the platform.
|
// and includes a link to the platform.
|
||||||
|
|||||||
225
internal/auth/copilot/copilot_auth.go
Normal file
225
internal/auth/copilot/copilot_auth.go
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
// Package copilot provides authentication and token management for GitHub Copilot API.
|
||||||
|
// It handles the OAuth2 device flow for secure authentication with the Copilot API.
|
||||||
|
package copilot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// copilotAPITokenURL is the endpoint for getting Copilot API tokens from GitHub token.
|
||||||
|
copilotAPITokenURL = "https://api.github.com/copilot_internal/v2/token"
|
||||||
|
// copilotAPIEndpoint is the base URL for making API requests.
|
||||||
|
copilotAPIEndpoint = "https://api.githubcopilot.com"
|
||||||
|
|
||||||
|
// Common HTTP header values for Copilot API requests.
|
||||||
|
copilotUserAgent = "GithubCopilot/1.0"
|
||||||
|
copilotEditorVersion = "vscode/1.100.0"
|
||||||
|
copilotPluginVersion = "copilot/1.300.0"
|
||||||
|
copilotIntegrationID = "vscode-chat"
|
||||||
|
copilotOpenAIIntent = "conversation-panel"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CopilotAPIToken represents the Copilot API token response.
|
||||||
|
type CopilotAPIToken struct {
|
||||||
|
// Token is the JWT token for authenticating with the Copilot API.
|
||||||
|
Token string `json:"token"`
|
||||||
|
// ExpiresAt is the Unix timestamp when the token expires.
|
||||||
|
ExpiresAt int64 `json:"expires_at"`
|
||||||
|
// Endpoints contains the available API endpoints.
|
||||||
|
Endpoints struct {
|
||||||
|
API string `json:"api"`
|
||||||
|
Proxy string `json:"proxy"`
|
||||||
|
OriginTracker string `json:"origin-tracker"`
|
||||||
|
Telemetry string `json:"telemetry"`
|
||||||
|
} `json:"endpoints,omitempty"`
|
||||||
|
// ErrorDetails contains error information if the request failed.
|
||||||
|
ErrorDetails *struct {
|
||||||
|
URL string `json:"url"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
DocumentationURL string `json:"documentation_url"`
|
||||||
|
} `json:"error_details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopilotAuth handles GitHub Copilot authentication flow.
|
||||||
|
// It provides methods for device flow authentication and token management.
|
||||||
|
type CopilotAuth struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
deviceClient *DeviceFlowClient
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCopilotAuth creates a new CopilotAuth service instance.
|
||||||
|
// It initializes an HTTP client with proxy settings from the provided configuration.
|
||||||
|
func NewCopilotAuth(cfg *config.Config) *CopilotAuth {
|
||||||
|
return &CopilotAuth{
|
||||||
|
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 30 * time.Second}),
|
||||||
|
deviceClient: NewDeviceFlowClient(cfg),
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartDeviceFlow initiates the device flow authentication.
|
||||||
|
// Returns the device code response containing the user code and verification URI.
|
||||||
|
func (c *CopilotAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||||
|
return c.deviceClient.RequestDeviceCode(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForAuthorization polls for user authorization and returns the auth bundle.
|
||||||
|
func (c *CopilotAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotAuthBundle, error) {
|
||||||
|
tokenData, err := c.deviceClient.PollForToken(ctx, deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch the GitHub username
|
||||||
|
username, err := c.deviceClient.FetchUserInfo(ctx, tokenData.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("copilot: failed to fetch user info: %v", err)
|
||||||
|
username = "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CopilotAuthBundle{
|
||||||
|
TokenData: tokenData,
|
||||||
|
Username: username,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCopilotAPIToken exchanges a GitHub access token for a Copilot API token.
|
||||||
|
// This token is used to make authenticated requests to the Copilot API.
|
||||||
|
func (c *CopilotAuth) GetCopilotAPIToken(ctx context.Context, githubAccessToken string) (*CopilotAPIToken, error) {
|
||||||
|
if githubAccessToken == "" {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("github access token is empty"))
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotAPITokenURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", "token "+githubAccessToken)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("User-Agent", copilotUserAgent)
|
||||||
|
req.Header.Set("Editor-Version", copilotEditorVersion)
|
||||||
|
req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("copilot api token: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isHTTPSuccess(resp.StatusCode) {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed,
|
||||||
|
fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiToken CopilotAPIToken
|
||||||
|
if err = json.Unmarshal(bodyBytes, &apiToken); err != nil {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if apiToken.Token == "" {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty copilot api token"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &apiToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken checks if a GitHub access token is valid by attempting to fetch user info.
|
||||||
|
func (c *CopilotAuth) ValidateToken(ctx context.Context, accessToken string) (bool, string, error) {
|
||||||
|
if accessToken == "" {
|
||||||
|
return false, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
username, err := c.deviceClient.FetchUserInfo(ctx, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, username, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTokenStorage creates a new CopilotTokenStorage from auth bundle.
|
||||||
|
func (c *CopilotAuth) CreateTokenStorage(bundle *CopilotAuthBundle) *CopilotTokenStorage {
|
||||||
|
return &CopilotTokenStorage{
|
||||||
|
AccessToken: bundle.TokenData.AccessToken,
|
||||||
|
TokenType: bundle.TokenData.TokenType,
|
||||||
|
Scope: bundle.TokenData.Scope,
|
||||||
|
Username: bundle.Username,
|
||||||
|
Type: "github-copilot",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAndValidateToken loads a token from storage and validates it.
|
||||||
|
// Returns the storage if valid, or an error if the token is invalid or expired.
|
||||||
|
func (c *CopilotAuth) LoadAndValidateToken(ctx context.Context, storage *CopilotTokenStorage) (bool, error) {
|
||||||
|
if storage == nil || storage.AccessToken == "" {
|
||||||
|
return false, fmt.Errorf("no token available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we can still use the GitHub token to get a Copilot API token
|
||||||
|
apiToken, err := c.GetCopilotAPIToken(ctx, storage.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the API token is expired
|
||||||
|
if apiToken.ExpiresAt > 0 && time.Now().Unix() >= apiToken.ExpiresAt {
|
||||||
|
return false, fmt.Errorf("copilot api token expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAPIEndpoint returns the Copilot API endpoint URL.
|
||||||
|
func (c *CopilotAuth) GetAPIEndpoint() string {
|
||||||
|
return copilotAPIEndpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
// MakeAuthenticatedRequest creates an authenticated HTTP request to the Copilot API.
|
||||||
|
func (c *CopilotAuth) MakeAuthenticatedRequest(ctx context.Context, method, url string, body io.Reader, apiToken *CopilotAPIToken) (*http.Request, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiToken.Token)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("User-Agent", copilotUserAgent)
|
||||||
|
req.Header.Set("Editor-Version", copilotEditorVersion)
|
||||||
|
req.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
|
||||||
|
req.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
||||||
|
req.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildChatCompletionURL builds the URL for chat completions API.
|
||||||
|
func buildChatCompletionURL() string {
|
||||||
|
return copilotAPIEndpoint + "/chat/completions"
|
||||||
|
}
|
||||||
|
|
||||||
|
// isHTTPSuccess checks if the status code indicates success (2xx).
|
||||||
|
func isHTTPSuccess(statusCode int) bool {
|
||||||
|
return statusCode >= 200 && statusCode < 300
|
||||||
|
}
|
||||||
187
internal/auth/copilot/errors.go
Normal file
187
internal/auth/copilot/errors.go
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
package copilot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OAuthError represents an OAuth-specific error.
|
||||||
|
type OAuthError struct {
|
||||||
|
// Code is the OAuth error code.
|
||||||
|
Code string `json:"error"`
|
||||||
|
// Description is a human-readable description of the error.
|
||||||
|
Description string `json:"error_description,omitempty"`
|
||||||
|
// URI is a URI identifying a human-readable web page with information about the error.
|
||||||
|
URI string `json:"error_uri,omitempty"`
|
||||||
|
// StatusCode is the HTTP status code associated with the error.
|
||||||
|
StatusCode int `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns a string representation of the OAuth error.
|
||||||
|
func (e *OAuthError) Error() string {
|
||||||
|
if e.Description != "" {
|
||||||
|
return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("OAuth error: %s", e.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOAuthError creates a new OAuth error with the specified code, description, and status code.
|
||||||
|
func NewOAuthError(code, description string, statusCode int) *OAuthError {
|
||||||
|
return &OAuthError{
|
||||||
|
Code: code,
|
||||||
|
Description: description,
|
||||||
|
StatusCode: statusCode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticationError represents authentication-related errors.
|
||||||
|
type AuthenticationError struct {
|
||||||
|
// Type is the type of authentication error.
|
||||||
|
Type string `json:"type"`
|
||||||
|
// Message is a human-readable message describing the error.
|
||||||
|
Message string `json:"message"`
|
||||||
|
// Code is the HTTP status code associated with the error.
|
||||||
|
Code int `json:"code"`
|
||||||
|
// Cause is the underlying error that caused this authentication error.
|
||||||
|
Cause error `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns a string representation of the authentication error.
|
||||||
|
func (e *AuthenticationError) Error() string {
|
||||||
|
if e.Cause != nil {
|
||||||
|
return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s: %s", e.Type, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap returns the underlying cause of the error.
|
||||||
|
func (e *AuthenticationError) Unwrap() error {
|
||||||
|
return e.Cause
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common authentication error types for GitHub Copilot device flow.
|
||||||
|
var (
|
||||||
|
// ErrDeviceCodeFailed represents an error when requesting the device code fails.
|
||||||
|
ErrDeviceCodeFailed = &AuthenticationError{
|
||||||
|
Type: "device_code_failed",
|
||||||
|
Message: "Failed to request device code from GitHub",
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrDeviceCodeExpired represents an error when the device code has expired.
|
||||||
|
ErrDeviceCodeExpired = &AuthenticationError{
|
||||||
|
Type: "device_code_expired",
|
||||||
|
Message: "Device code has expired. Please try again.",
|
||||||
|
Code: http.StatusGone,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrAuthorizationPending represents a pending authorization state (not an error, used for polling).
|
||||||
|
ErrAuthorizationPending = &AuthenticationError{
|
||||||
|
Type: "authorization_pending",
|
||||||
|
Message: "Authorization is pending. Waiting for user to authorize.",
|
||||||
|
Code: http.StatusAccepted,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrSlowDown represents a request to slow down polling.
|
||||||
|
ErrSlowDown = &AuthenticationError{
|
||||||
|
Type: "slow_down",
|
||||||
|
Message: "Polling too frequently. Slowing down.",
|
||||||
|
Code: http.StatusTooManyRequests,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrAccessDenied represents an error when the user denies authorization.
|
||||||
|
ErrAccessDenied = &AuthenticationError{
|
||||||
|
Type: "access_denied",
|
||||||
|
Message: "User denied authorization",
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrTokenExchangeFailed represents an error when token exchange fails.
|
||||||
|
ErrTokenExchangeFailed = &AuthenticationError{
|
||||||
|
Type: "token_exchange_failed",
|
||||||
|
Message: "Failed to exchange device code for access token",
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrPollingTimeout represents an error when polling times out.
|
||||||
|
ErrPollingTimeout = &AuthenticationError{
|
||||||
|
Type: "polling_timeout",
|
||||||
|
Message: "Timeout waiting for user authorization",
|
||||||
|
Code: http.StatusRequestTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrUserInfoFailed represents an error when fetching user info fails.
|
||||||
|
ErrUserInfoFailed = &AuthenticationError{
|
||||||
|
Type: "user_info_failed",
|
||||||
|
Message: "Failed to fetch GitHub user information",
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewAuthenticationError creates a new authentication error with a cause based on a base error.
|
||||||
|
func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
|
||||||
|
return &AuthenticationError{
|
||||||
|
Type: baseErr.Type,
|
||||||
|
Message: baseErr.Message,
|
||||||
|
Code: baseErr.Code,
|
||||||
|
Cause: cause,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAuthenticationError checks if an error is an authentication error.
|
||||||
|
func IsAuthenticationError(err error) bool {
|
||||||
|
var authenticationError *AuthenticationError
|
||||||
|
ok := errors.As(err, &authenticationError)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsOAuthError checks if an error is an OAuth error.
|
||||||
|
func IsOAuthError(err error) bool {
|
||||||
|
var oAuthError *OAuthError
|
||||||
|
ok := errors.As(err, &oAuthError)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserFriendlyMessage returns a user-friendly error message based on the error type.
|
||||||
|
func GetUserFriendlyMessage(err error) string {
|
||||||
|
var authErr *AuthenticationError
|
||||||
|
if errors.As(err, &authErr) {
|
||||||
|
switch authErr.Type {
|
||||||
|
case "device_code_failed":
|
||||||
|
return "Failed to start GitHub authentication. Please check your network connection and try again."
|
||||||
|
case "device_code_expired":
|
||||||
|
return "The authentication code has expired. Please try again."
|
||||||
|
case "authorization_pending":
|
||||||
|
return "Waiting for you to authorize the application on GitHub."
|
||||||
|
case "slow_down":
|
||||||
|
return "Please wait a moment before trying again."
|
||||||
|
case "access_denied":
|
||||||
|
return "Authentication was cancelled or denied."
|
||||||
|
case "token_exchange_failed":
|
||||||
|
return "Failed to complete authentication. Please try again."
|
||||||
|
case "polling_timeout":
|
||||||
|
return "Authentication timed out. Please try again."
|
||||||
|
case "user_info_failed":
|
||||||
|
return "Failed to get your GitHub account information. Please try again."
|
||||||
|
default:
|
||||||
|
return "Authentication failed. Please try again."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var oauthErr *OAuthError
|
||||||
|
if errors.As(err, &oauthErr) {
|
||||||
|
switch oauthErr.Code {
|
||||||
|
case "access_denied":
|
||||||
|
return "Authentication was cancelled or denied."
|
||||||
|
case "invalid_request":
|
||||||
|
return "Invalid authentication request. Please try again."
|
||||||
|
case "server_error":
|
||||||
|
return "GitHub server error. Please try again later."
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "An unexpected error occurred. Please try again."
|
||||||
|
}
|
||||||
255
internal/auth/copilot/oauth.go
Normal file
255
internal/auth/copilot/oauth.go
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
package copilot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// copilotClientID is GitHub's Copilot CLI OAuth client ID.
|
||||||
|
copilotClientID = "Iv1.b507a08c87ecfe98"
|
||||||
|
// copilotDeviceCodeURL is the endpoint for requesting device codes.
|
||||||
|
copilotDeviceCodeURL = "https://github.com/login/device/code"
|
||||||
|
// copilotTokenURL is the endpoint for exchanging device codes for tokens.
|
||||||
|
copilotTokenURL = "https://github.com/login/oauth/access_token"
|
||||||
|
// copilotUserInfoURL is the endpoint for fetching GitHub user information.
|
||||||
|
copilotUserInfoURL = "https://api.github.com/user"
|
||||||
|
// defaultPollInterval is the default interval for polling token endpoint.
|
||||||
|
defaultPollInterval = 5 * time.Second
|
||||||
|
// maxPollDuration is the maximum time to wait for user authorization.
|
||||||
|
maxPollDuration = 15 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeviceFlowClient handles the OAuth2 device flow for GitHub Copilot.
|
||||||
|
type DeviceFlowClient struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDeviceFlowClient creates a new device flow client.
|
||||||
|
func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
||||||
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if cfg != nil {
|
||||||
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
|
}
|
||||||
|
return &DeviceFlowClient{
|
||||||
|
httpClient: client,
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestDeviceCode initiates the device flow by requesting a device code from GitHub.
|
||||||
|
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", copilotClientID)
|
||||||
|
data.Set("scope", "user:email")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("copilot device code: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if !isHTTPSuccess(resp.StatusCode) {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, NewAuthenticationError(ErrDeviceCodeFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var deviceCode DeviceCodeResponse
|
||||||
|
if err = json.NewDecoder(resp.Body).Decode(&deviceCode); err != nil {
|
||||||
|
return nil, NewAuthenticationError(ErrDeviceCodeFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &deviceCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PollForToken polls the token endpoint until the user authorizes or the device code expires.
|
||||||
|
func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*CopilotTokenData, error) {
|
||||||
|
if deviceCode == nil {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("device code is nil"))
|
||||||
|
}
|
||||||
|
|
||||||
|
interval := time.Duration(deviceCode.Interval) * time.Second
|
||||||
|
if interval < defaultPollInterval {
|
||||||
|
interval = defaultPollInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(maxPollDuration)
|
||||||
|
if deviceCode.ExpiresIn > 0 {
|
||||||
|
codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
|
||||||
|
if codeDeadline.Before(deadline) {
|
||||||
|
deadline = codeDeadline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, NewAuthenticationError(ErrPollingTimeout, ctx.Err())
|
||||||
|
case <-ticker.C:
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
return nil, ErrPollingTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode)
|
||||||
|
if err != nil {
|
||||||
|
var authErr *AuthenticationError
|
||||||
|
if errors.As(err, &authErr) {
|
||||||
|
switch authErr.Type {
|
||||||
|
case ErrAuthorizationPending.Type:
|
||||||
|
// Continue polling
|
||||||
|
continue
|
||||||
|
case ErrSlowDown.Type:
|
||||||
|
// Increase interval and continue
|
||||||
|
interval += 5 * time.Second
|
||||||
|
ticker.Reset(interval)
|
||||||
|
continue
|
||||||
|
case ErrDeviceCodeExpired.Type:
|
||||||
|
return nil, err
|
||||||
|
case ErrAccessDenied.Type:
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// exchangeDeviceCode attempts to exchange the device code for an access token.
|
||||||
|
func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*CopilotTokenData, error) {
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", copilotClientID)
|
||||||
|
data.Set("device_code", deviceCode)
|
||||||
|
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, copilotTokenURL, strings.NewReader(data.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("copilot token exchange: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
bodyBytes, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GitHub returns 200 for both success and error cases in device flow
|
||||||
|
// Check for OAuth error response first
|
||||||
|
var oauthResp struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
ErrorDescription string `json:"error_description"`
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if oauthResp.Error != "" {
|
||||||
|
switch oauthResp.Error {
|
||||||
|
case "authorization_pending":
|
||||||
|
return nil, ErrAuthorizationPending
|
||||||
|
case "slow_down":
|
||||||
|
return nil, ErrSlowDown
|
||||||
|
case "expired_token":
|
||||||
|
return nil, ErrDeviceCodeExpired
|
||||||
|
case "access_denied":
|
||||||
|
return nil, ErrAccessDenied
|
||||||
|
default:
|
||||||
|
return nil, NewOAuthError(oauthResp.Error, oauthResp.ErrorDescription, resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if oauthResp.AccessToken == "" {
|
||||||
|
return nil, NewAuthenticationError(ErrTokenExchangeFailed, fmt.Errorf("empty access token"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CopilotTokenData{
|
||||||
|
AccessToken: oauthResp.AccessToken,
|
||||||
|
TokenType: oauthResp.TokenType,
|
||||||
|
Scope: oauthResp.Scope,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchUserInfo retrieves the GitHub username for the authenticated user.
|
||||||
|
func (c *DeviceFlowClient) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
|
||||||
|
if accessToken == "" {
|
||||||
|
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("access token is empty"))
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, copilotUserInfoURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
req.Header.Set("User-Agent", "CLIProxyAPI")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("copilot user info: close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if !isHTTPSuccess(resp.StatusCode) {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("status %d: %s", resp.StatusCode, string(bodyBytes)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var userInfo struct {
|
||||||
|
Login string `json:"login"`
|
||||||
|
}
|
||||||
|
if err = json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
|
||||||
|
return "", NewAuthenticationError(ErrUserInfoFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userInfo.Login == "" {
|
||||||
|
return "", NewAuthenticationError(ErrUserInfoFailed, fmt.Errorf("empty username"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return userInfo.Login, nil
|
||||||
|
}
|
||||||
93
internal/auth/copilot/token.go
Normal file
93
internal/auth/copilot/token.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
// Package copilot provides authentication and token management functionality
|
||||||
|
// for GitHub Copilot AI services. It handles OAuth2 device flow token storage,
|
||||||
|
// serialization, and retrieval for maintaining authenticated sessions with the Copilot API.
|
||||||
|
package copilot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CopilotTokenStorage stores OAuth2 token information for GitHub Copilot API authentication.
|
||||||
|
// It maintains compatibility with the existing auth system while adding Copilot-specific fields
|
||||||
|
// for managing access tokens and user account information.
|
||||||
|
type CopilotTokenStorage struct {
|
||||||
|
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
// TokenType is the type of token, typically "bearer".
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
// Scope is the OAuth2 scope granted to the token.
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
// ExpiresAt is the timestamp when the access token expires (if provided).
|
||||||
|
ExpiresAt string `json:"expires_at,omitempty"`
|
||||||
|
// Username is the GitHub username associated with this token.
|
||||||
|
Username string `json:"username"`
|
||||||
|
// Type indicates the authentication provider type, always "github-copilot" for this storage.
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopilotTokenData holds the raw OAuth token response from GitHub.
|
||||||
|
type CopilotTokenData struct {
|
||||||
|
// AccessToken is the OAuth2 access token.
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
// TokenType is the type of token, typically "bearer".
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
// Scope is the OAuth2 scope granted to the token.
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopilotAuthBundle bundles authentication data for storage.
|
||||||
|
type CopilotAuthBundle struct {
|
||||||
|
// TokenData contains the OAuth token information.
|
||||||
|
TokenData *CopilotTokenData
|
||||||
|
// Username is the GitHub username.
|
||||||
|
Username string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceCodeResponse represents GitHub's device code response.
|
||||||
|
type DeviceCodeResponse struct {
|
||||||
|
// DeviceCode is the device verification code.
|
||||||
|
DeviceCode string `json:"device_code"`
|
||||||
|
// UserCode is the code the user must enter at the verification URI.
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
// VerificationURI is the URL where the user should enter the code.
|
||||||
|
VerificationURI string `json:"verification_uri"`
|
||||||
|
// ExpiresIn is the number of seconds until the device code expires.
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
// Interval is the minimum number of seconds to wait between polling requests.
|
||||||
|
Interval int `json:"interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveTokenToFile serializes the Copilot token storage to a JSON file.
|
||||||
|
// This method creates the necessary directory structure and writes the token
|
||||||
|
// data in JSON format to the specified file path for persistent storage.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - authFilePath: The full path where the token file should be saved
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if the operation fails, nil otherwise
|
||||||
|
func (ts *CopilotTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
|
misc.LogSavingCredentials(authFilePath)
|
||||||
|
ts.Type = "github-copilot"
|
||||||
|
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Create(authFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create token file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = f.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||||
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -28,10 +29,21 @@ const (
|
|||||||
iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey"
|
iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey"
|
||||||
|
|
||||||
// Client credentials provided by iFlow for the Code Assist integration.
|
// Client credentials provided by iFlow for the Code Assist integration.
|
||||||
iFlowOAuthClientID = "10009311001"
|
iFlowOAuthClientID = "10009311001"
|
||||||
iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
|
// Default client secret (can be overridden via IFLOW_CLIENT_SECRET env var)
|
||||||
|
defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// getIFlowClientSecret returns the iFlow OAuth client secret.
|
||||||
|
// It first checks the IFLOW_CLIENT_SECRET environment variable,
|
||||||
|
// falling back to the default value if not set.
|
||||||
|
func getIFlowClientSecret() string {
|
||||||
|
if secret := os.Getenv("IFLOW_CLIENT_SECRET"); secret != "" {
|
||||||
|
return secret
|
||||||
|
}
|
||||||
|
return defaultIFlowClientSecret
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultAPIBaseURL is the canonical chat completions endpoint.
|
// DefaultAPIBaseURL is the canonical chat completions endpoint.
|
||||||
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1"
|
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1"
|
||||||
|
|
||||||
@@ -72,7 +84,7 @@ func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectUR
|
|||||||
form.Set("code", code)
|
form.Set("code", code)
|
||||||
form.Set("redirect_uri", redirectURI)
|
form.Set("redirect_uri", redirectURI)
|
||||||
form.Set("client_id", iFlowOAuthClientID)
|
form.Set("client_id", iFlowOAuthClientID)
|
||||||
form.Set("client_secret", iFlowOAuthClientSecret)
|
form.Set("client_secret", getIFlowClientSecret())
|
||||||
|
|
||||||
req, err := ia.newTokenRequest(ctx, form)
|
req, err := ia.newTokenRequest(ctx, form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -88,7 +100,7 @@ func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*I
|
|||||||
form.Set("grant_type", "refresh_token")
|
form.Set("grant_type", "refresh_token")
|
||||||
form.Set("refresh_token", refreshToken)
|
form.Set("refresh_token", refreshToken)
|
||||||
form.Set("client_id", iFlowOAuthClientID)
|
form.Set("client_id", iFlowOAuthClientID)
|
||||||
form.Set("client_secret", iFlowOAuthClientSecret)
|
form.Set("client_secret", getIFlowClientSecret())
|
||||||
|
|
||||||
req, err := ia.newTokenRequest(ctx, form)
|
req, err := ia.newTokenRequest(ctx, form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -104,7 +116,7 @@ func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*htt
|
|||||||
return nil, fmt.Errorf("iflow token: create request failed: %w", err)
|
return nil, fmt.Errorf("iflow token: create request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret))
|
basic := base64.StdEncoding.EncodeToString([]byte(iFlowOAuthClientID + ":" + getIFlowClientSecret()))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header.Set("Accept", "application/json")
|
||||||
req.Header.Set("Authorization", "Basic "+basic)
|
req.Header.Set("Authorization", "Basic "+basic)
|
||||||
|
|||||||
301
internal/auth/kiro/aws.go
Normal file
301
internal/auth/kiro/aws.go
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
// Package kiro provides authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||||
|
// It includes interfaces and implementations for token storage and authentication methods.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
|
||||||
|
type PKCECodes struct {
|
||||||
|
// CodeVerifier is the cryptographically random string used to correlate
|
||||||
|
// the authorization request to the token request
|
||||||
|
CodeVerifier string `json:"code_verifier"`
|
||||||
|
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
|
||||||
|
CodeChallenge string `json:"code_challenge"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroTokenData holds OAuth token information from AWS CodeWhisperer (Kiro)
|
||||||
|
type KiroTokenData struct {
|
||||||
|
// AccessToken is the OAuth2 access token for API access
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
// RefreshToken is used to obtain new access tokens
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
// ProfileArn is the AWS CodeWhisperer profile ARN
|
||||||
|
ProfileArn string `json:"profileArn"`
|
||||||
|
// ExpiresAt is the timestamp when the token expires
|
||||||
|
ExpiresAt string `json:"expiresAt"`
|
||||||
|
// AuthMethod indicates the authentication method used (e.g., "builder-id", "social")
|
||||||
|
AuthMethod string `json:"authMethod"`
|
||||||
|
// Provider indicates the OAuth provider (e.g., "AWS", "Google")
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
// ClientID is the OIDC client ID (needed for token refresh)
|
||||||
|
ClientID string `json:"clientId,omitempty"`
|
||||||
|
// ClientSecret is the OIDC client secret (needed for token refresh)
|
||||||
|
ClientSecret string `json:"clientSecret,omitempty"`
|
||||||
|
// Email is the user's email address (used for file naming)
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroAuthBundle aggregates authentication data after OAuth flow completion
|
||||||
|
type KiroAuthBundle struct {
|
||||||
|
// TokenData contains the OAuth tokens from the authentication flow
|
||||||
|
TokenData KiroTokenData `json:"token_data"`
|
||||||
|
// LastRefresh is the timestamp of the last token refresh
|
||||||
|
LastRefresh string `json:"last_refresh"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroUsageInfo represents usage information from CodeWhisperer API
|
||||||
|
type KiroUsageInfo struct {
|
||||||
|
// SubscriptionTitle is the subscription plan name (e.g., "KIRO FREE")
|
||||||
|
SubscriptionTitle string `json:"subscription_title"`
|
||||||
|
// CurrentUsage is the current credit usage
|
||||||
|
CurrentUsage float64 `json:"current_usage"`
|
||||||
|
// UsageLimit is the maximum credit limit
|
||||||
|
UsageLimit float64 `json:"usage_limit"`
|
||||||
|
// NextReset is the timestamp of the next usage reset
|
||||||
|
NextReset string `json:"next_reset"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroModel represents a model available through the CodeWhisperer API
|
||||||
|
type KiroModel struct {
|
||||||
|
// ModelID is the unique identifier for the model
|
||||||
|
ModelID string `json:"modelId"`
|
||||||
|
// ModelName is the human-readable name
|
||||||
|
ModelName string `json:"modelName"`
|
||||||
|
// Description is the model description
|
||||||
|
Description string `json:"description"`
|
||||||
|
// RateMultiplier is the credit multiplier for this model
|
||||||
|
RateMultiplier float64 `json:"rateMultiplier"`
|
||||||
|
// RateUnit is the unit for rate calculation (e.g., "credit")
|
||||||
|
RateUnit string `json:"rateUnit"`
|
||||||
|
// MaxInputTokens is the maximum input token limit
|
||||||
|
MaxInputTokens int `json:"maxInputTokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroIDETokenFile is the default path to Kiro IDE's token file
|
||||||
|
const KiroIDETokenFile = ".aws/sso/cache/kiro-auth-token.json"
|
||||||
|
|
||||||
|
// LoadKiroIDEToken loads token data from Kiro IDE's token file.
|
||||||
|
func LoadKiroIDEToken() (*KiroTokenData, error) {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenPath := filepath.Join(homeDir, KiroIDETokenFile)
|
||||||
|
data, err := os.ReadFile(tokenPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read Kiro IDE token file (%s): %w", tokenPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var token KiroTokenData
|
||||||
|
if err := json.Unmarshal(data, &token); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse Kiro IDE token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("access token is empty in Kiro IDE token file")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadKiroTokenFromPath loads token data from a custom path.
|
||||||
|
// This supports multiple accounts by allowing different token files.
|
||||||
|
func LoadKiroTokenFromPath(tokenPath string) (*KiroTokenData, error) {
|
||||||
|
// Expand ~ to home directory
|
||||||
|
if len(tokenPath) > 0 && tokenPath[0] == '~' {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
tokenPath = filepath.Join(homeDir, tokenPath[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(tokenPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read token file (%s): %w", tokenPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var token KiroTokenData
|
||||||
|
if err := json.Unmarshal(data, &token); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if token.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("access token is empty in token file")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListKiroTokenFiles lists all Kiro token files in the cache directory.
|
||||||
|
// This supports multiple accounts by finding all token files.
|
||||||
|
func ListKiroTokenFiles() ([]string, error) {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheDir := filepath.Join(homeDir, ".aws", "sso", "cache")
|
||||||
|
|
||||||
|
// Check if directory exists
|
||||||
|
if _, err := os.Stat(cacheDir); os.IsNotExist(err) {
|
||||||
|
return nil, nil // No token files
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(cacheDir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read cache directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenFiles []string
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := entry.Name()
|
||||||
|
// Look for kiro token files only (avoid matching unrelated AWS SSO cache files)
|
||||||
|
if strings.HasSuffix(name, ".json") && strings.HasPrefix(name, "kiro") {
|
||||||
|
tokenFiles = append(tokenFiles, filepath.Join(cacheDir, name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenFiles, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAllKiroTokens loads all Kiro tokens from the cache directory.
|
||||||
|
// This supports multiple accounts.
|
||||||
|
func LoadAllKiroTokens() ([]*KiroTokenData, error) {
|
||||||
|
files, err := ListKiroTokenFiles()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokens []*KiroTokenData
|
||||||
|
for _, file := range files {
|
||||||
|
token, err := LoadKiroTokenFromPath(file)
|
||||||
|
if err != nil {
|
||||||
|
// Skip invalid token files
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWTClaims represents the claims we care about from a JWT token.
|
||||||
|
// JWT tokens from Kiro/AWS contain user information in the payload.
|
||||||
|
type JWTClaims struct {
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
Sub string `json:"sub,omitempty"`
|
||||||
|
PreferredUser string `json:"preferred_username,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Iss string `json:"iss,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractEmailFromJWT extracts the user's email from a JWT access token.
|
||||||
|
// JWT tokens typically have format: header.payload.signature
|
||||||
|
// The payload is base64url-encoded JSON containing user claims.
|
||||||
|
func ExtractEmailFromJWT(accessToken string) string {
|
||||||
|
if accessToken == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// JWT format: header.payload.signature
|
||||||
|
parts := strings.Split(accessToken, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the payload (second part)
|
||||||
|
payload := parts[1]
|
||||||
|
|
||||||
|
// Add padding if needed (base64url requires padding)
|
||||||
|
switch len(payload) % 4 {
|
||||||
|
case 2:
|
||||||
|
payload += "=="
|
||||||
|
case 3:
|
||||||
|
payload += "="
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := base64.URLEncoding.DecodeString(payload)
|
||||||
|
if err != nil {
|
||||||
|
// Try RawURLEncoding (no padding)
|
||||||
|
decoded, err = base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims JWTClaims
|
||||||
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return email if available
|
||||||
|
if claims.Email != "" {
|
||||||
|
return claims.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to preferred_username (some providers use this)
|
||||||
|
if claims.PreferredUser != "" && strings.Contains(claims.PreferredUser, "@") {
|
||||||
|
return claims.PreferredUser
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to sub if it looks like an email
|
||||||
|
if claims.Sub != "" && strings.Contains(claims.Sub, "@") {
|
||||||
|
return claims.Sub
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeEmailForFilename sanitizes an email address for use in a filename.
|
||||||
|
// Replaces special characters with underscores and prevents path traversal attacks.
|
||||||
|
// Also handles URL-encoded characters to prevent encoded path traversal attempts.
|
||||||
|
func SanitizeEmailForFilename(email string) string {
|
||||||
|
if email == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
result := email
|
||||||
|
|
||||||
|
// First, handle URL-encoded path traversal attempts (%2F, %2E, %5C, etc.)
|
||||||
|
// This prevents encoded characters from bypassing the sanitization.
|
||||||
|
// Note: We replace % last to catch any remaining encodings including double-encoding (%252F)
|
||||||
|
result = strings.ReplaceAll(result, "%2F", "_") // /
|
||||||
|
result = strings.ReplaceAll(result, "%2f", "_")
|
||||||
|
result = strings.ReplaceAll(result, "%5C", "_") // \
|
||||||
|
result = strings.ReplaceAll(result, "%5c", "_")
|
||||||
|
result = strings.ReplaceAll(result, "%2E", "_") // .
|
||||||
|
result = strings.ReplaceAll(result, "%2e", "_")
|
||||||
|
result = strings.ReplaceAll(result, "%00", "_") // null byte
|
||||||
|
result = strings.ReplaceAll(result, "%", "_") // Catch remaining % to prevent double-encoding attacks
|
||||||
|
|
||||||
|
// Replace characters that are problematic in filenames
|
||||||
|
// Keep @ and . in middle but replace other special characters
|
||||||
|
for _, char := range []string{"/", "\\", ":", "*", "?", "\"", "<", ">", "|", " ", "\x00"} {
|
||||||
|
result = strings.ReplaceAll(result, char, "_")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevent path traversal: replace leading dots in each path component
|
||||||
|
// This handles cases like "../../../etc/passwd" → "_.._.._.._etc_passwd"
|
||||||
|
parts := strings.Split(result, "_")
|
||||||
|
for i, part := range parts {
|
||||||
|
for strings.HasPrefix(part, ".") {
|
||||||
|
part = "_" + part[1:]
|
||||||
|
}
|
||||||
|
parts[i] = part
|
||||||
|
}
|
||||||
|
result = strings.Join(parts, "_")
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
314
internal/auth/kiro/aws_auth.go
Normal file
314
internal/auth/kiro/aws_auth.go
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
// Package kiro provides OAuth2 authentication functionality for AWS CodeWhisperer (Kiro) API.
|
||||||
|
// This package implements token loading, refresh, and API communication with CodeWhisperer.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// awsKiroEndpoint is used for CodeWhisperer management APIs (GetUsageLimits, ListProfiles, etc.)
|
||||||
|
// Note: This is different from the Amazon Q streaming endpoint (q.us-east-1.amazonaws.com)
|
||||||
|
// used in kiro_executor.go for GenerateAssistantResponse. Both endpoints are correct
|
||||||
|
// for their respective API operations.
|
||||||
|
awsKiroEndpoint = "https://codewhisperer.us-east-1.amazonaws.com"
|
||||||
|
defaultTokenFile = "~/.aws/sso/cache/kiro-auth-token.json"
|
||||||
|
targetGetUsage = "AmazonCodeWhispererService.GetUsageLimits"
|
||||||
|
targetListModels = "AmazonCodeWhispererService.ListAvailableModels"
|
||||||
|
targetGenerateChat = "AmazonCodeWhispererStreamingService.GenerateAssistantResponse"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KiroAuth handles AWS CodeWhisperer authentication and API communication.
|
||||||
|
// It provides methods for loading tokens, refreshing expired tokens,
|
||||||
|
// and communicating with the CodeWhisperer API.
|
||||||
|
type KiroAuth struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
endpoint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKiroAuth creates a new Kiro authentication service.
|
||||||
|
// It initializes the HTTP client with proxy settings from the configuration.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration containing proxy settings
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *KiroAuth: A new Kiro authentication service instance
|
||||||
|
func NewKiroAuth(cfg *config.Config) *KiroAuth {
|
||||||
|
return &KiroAuth{
|
||||||
|
httpClient: util.SetProxy(&cfg.SDKConfig, &http.Client{Timeout: 120 * time.Second}),
|
||||||
|
endpoint: awsKiroEndpoint,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadTokenFromFile loads token data from a file path.
|
||||||
|
// This method reads and parses the token file, expanding ~ to the home directory.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tokenFile: Path to the token file (supports ~ expansion)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *KiroTokenData: The parsed token data
|
||||||
|
// - error: An error if file reading or parsing fails
|
||||||
|
func (k *KiroAuth) LoadTokenFromFile(tokenFile string) (*KiroTokenData, error) {
|
||||||
|
// Expand ~ to home directory
|
||||||
|
if strings.HasPrefix(tokenFile, "~") {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
tokenFile = filepath.Join(home, tokenFile[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(tokenFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenData KiroTokenData
|
||||||
|
if err := json.Unmarshal(data, &tokenData); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTokenExpired checks if the token has expired.
|
||||||
|
// This method parses the expiration timestamp and compares it with the current time.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tokenData: The token data to check
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - bool: True if the token has expired, false otherwise
|
||||||
|
func (k *KiroAuth) IsTokenExpired(tokenData *KiroTokenData) bool {
|
||||||
|
if tokenData.ExpiresAt == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
// Try alternate format
|
||||||
|
expiresAt, err = time.Parse("2006-01-02T15:04:05.000Z", tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Now().After(expiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeRequest sends a request to the CodeWhisperer API.
|
||||||
|
// This is an internal method for making authenticated API calls.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ctx: The context for the request
|
||||||
|
// - target: The API target (e.g., "AmazonCodeWhispererService.GetUsageLimits")
|
||||||
|
// - accessToken: The OAuth access token
|
||||||
|
// - payload: The request payload
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []byte: The response body
|
||||||
|
// - error: An error if the request fails
|
||||||
|
func (k *KiroAuth) makeRequest(ctx context.Context, target string, accessToken string, payload interface{}) ([]byte, error) {
|
||||||
|
jsonBody, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, k.endpoint, strings.NewReader(string(jsonBody)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||||
|
req.Header.Set("x-amz-target", target)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := k.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("failed to close response body: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUsageLimits retrieves usage information from the CodeWhisperer API.
|
||||||
|
// This method fetches the current usage statistics and subscription information.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ctx: The context for the request
|
||||||
|
// - tokenData: The token data containing access token and profile ARN
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *KiroUsageInfo: The usage information
|
||||||
|
// - error: An error if the request fails
|
||||||
|
func (k *KiroAuth) GetUsageLimits(ctx context.Context, tokenData *KiroTokenData) (*KiroUsageInfo, error) {
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"origin": "AI_EDITOR",
|
||||||
|
"profileArn": tokenData.ProfileArn,
|
||||||
|
"resourceType": "AGENTIC_REQUEST",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := k.makeRequest(ctx, targetGetUsage, tokenData.AccessToken, payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
SubscriptionInfo struct {
|
||||||
|
SubscriptionTitle string `json:"subscriptionTitle"`
|
||||||
|
} `json:"subscriptionInfo"`
|
||||||
|
UsageBreakdownList []struct {
|
||||||
|
CurrentUsageWithPrecision float64 `json:"currentUsageWithPrecision"`
|
||||||
|
UsageLimitWithPrecision float64 `json:"usageLimitWithPrecision"`
|
||||||
|
} `json:"usageBreakdownList"`
|
||||||
|
NextDateReset float64 `json:"nextDateReset"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse usage response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
usage := &KiroUsageInfo{
|
||||||
|
SubscriptionTitle: result.SubscriptionInfo.SubscriptionTitle,
|
||||||
|
NextReset: fmt.Sprintf("%v", result.NextDateReset),
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.UsageBreakdownList) > 0 {
|
||||||
|
usage.CurrentUsage = result.UsageBreakdownList[0].CurrentUsageWithPrecision
|
||||||
|
usage.UsageLimit = result.UsageBreakdownList[0].UsageLimitWithPrecision
|
||||||
|
}
|
||||||
|
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAvailableModels retrieves available models from the CodeWhisperer API.
|
||||||
|
// This method fetches the list of AI models available for the authenticated user.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ctx: The context for the request
|
||||||
|
// - tokenData: The token data containing access token and profile ARN
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - []*KiroModel: The list of available models
|
||||||
|
// - error: An error if the request fails
|
||||||
|
func (k *KiroAuth) ListAvailableModels(ctx context.Context, tokenData *KiroTokenData) ([]*KiroModel, error) {
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"origin": "AI_EDITOR",
|
||||||
|
"profileArn": tokenData.ProfileArn,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := k.makeRequest(ctx, targetListModels, tokenData.AccessToken, payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Models []struct {
|
||||||
|
ModelID string `json:"modelId"`
|
||||||
|
ModelName string `json:"modelName"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
RateMultiplier float64 `json:"rateMultiplier"`
|
||||||
|
RateUnit string `json:"rateUnit"`
|
||||||
|
TokenLimits struct {
|
||||||
|
MaxInputTokens int `json:"maxInputTokens"`
|
||||||
|
} `json:"tokenLimits"`
|
||||||
|
} `json:"models"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse models response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
models := make([]*KiroModel, 0, len(result.Models))
|
||||||
|
for _, m := range result.Models {
|
||||||
|
models = append(models, &KiroModel{
|
||||||
|
ModelID: m.ModelID,
|
||||||
|
ModelName: m.ModelName,
|
||||||
|
Description: m.Description,
|
||||||
|
RateMultiplier: m.RateMultiplier,
|
||||||
|
RateUnit: m.RateUnit,
|
||||||
|
MaxInputTokens: m.TokenLimits.MaxInputTokens,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTokenStorage creates a new KiroTokenStorage from token data.
|
||||||
|
// This method converts the token data into a storage structure suitable for persistence.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - tokenData: The token data to convert
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *KiroTokenStorage: A new token storage instance
|
||||||
|
func (k *KiroAuth) CreateTokenStorage(tokenData *KiroTokenData) *KiroTokenStorage {
|
||||||
|
return &KiroTokenStorage{
|
||||||
|
AccessToken: tokenData.AccessToken,
|
||||||
|
RefreshToken: tokenData.RefreshToken,
|
||||||
|
ProfileArn: tokenData.ProfileArn,
|
||||||
|
ExpiresAt: tokenData.ExpiresAt,
|
||||||
|
AuthMethod: tokenData.AuthMethod,
|
||||||
|
Provider: tokenData.Provider,
|
||||||
|
LastRefresh: time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateToken checks if the token is valid by making a test API call.
|
||||||
|
// This method verifies the token by attempting to fetch usage limits.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ctx: The context for the request
|
||||||
|
// - tokenData: The token data to validate
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if the token is invalid
|
||||||
|
func (k *KiroAuth) ValidateToken(ctx context.Context, tokenData *KiroTokenData) error {
|
||||||
|
_, err := k.GetUsageLimits(ctx, tokenData)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTokenStorage updates an existing token storage with new token data.
|
||||||
|
// This method refreshes the token storage with newly obtained access and refresh tokens.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - storage: The existing token storage to update
|
||||||
|
// - tokenData: The new token data to apply
|
||||||
|
func (k *KiroAuth) UpdateTokenStorage(storage *KiroTokenStorage, tokenData *KiroTokenData) {
|
||||||
|
storage.AccessToken = tokenData.AccessToken
|
||||||
|
storage.RefreshToken = tokenData.RefreshToken
|
||||||
|
storage.ProfileArn = tokenData.ProfileArn
|
||||||
|
storage.ExpiresAt = tokenData.ExpiresAt
|
||||||
|
storage.AuthMethod = tokenData.AuthMethod
|
||||||
|
storage.Provider = tokenData.Provider
|
||||||
|
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||||
|
}
|
||||||
161
internal/auth/kiro/aws_test.go
Normal file
161
internal/auth/kiro/aws_test.go
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractEmailFromJWT(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty token",
|
||||||
|
token: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid token format",
|
||||||
|
token: "not.a.valid.jwt",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid token - not base64",
|
||||||
|
token: "xxx.yyy.zzz",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid JWT with email",
|
||||||
|
token: createTestJWT(map[string]any{"email": "test@example.com", "sub": "user123"}),
|
||||||
|
expected: "test@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JWT without email but with preferred_username",
|
||||||
|
token: createTestJWT(map[string]any{"preferred_username": "user@domain.com", "sub": "user123"}),
|
||||||
|
expected: "user@domain.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JWT with email-like sub",
|
||||||
|
token: createTestJWT(map[string]any{"sub": "another@test.com"}),
|
||||||
|
expected: "another@test.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JWT without any email fields",
|
||||||
|
token: createTestJWT(map[string]any{"sub": "user123", "name": "Test User"}),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ExtractEmailFromJWT(tt.token)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ExtractEmailFromJWT() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeEmailForFilename(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
email string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty email",
|
||||||
|
email: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Simple email",
|
||||||
|
email: "user@example.com",
|
||||||
|
expected: "user@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Email with space",
|
||||||
|
email: "user name@example.com",
|
||||||
|
expected: "user_name@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Email with special chars",
|
||||||
|
email: "user:name@example.com",
|
||||||
|
expected: "user_name@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Email with multiple special chars",
|
||||||
|
email: "user/name:test@example.com",
|
||||||
|
expected: "user_name_test@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Path traversal attempt",
|
||||||
|
email: "../../../etc/passwd",
|
||||||
|
expected: "_.__.__._etc_passwd",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Path traversal with backslash",
|
||||||
|
email: `..\..\..\..\windows\system32`,
|
||||||
|
expected: "_.__.__.__._windows_system32",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Null byte injection attempt",
|
||||||
|
email: "user\x00@evil.com",
|
||||||
|
expected: "user_@evil.com",
|
||||||
|
},
|
||||||
|
// URL-encoded path traversal tests
|
||||||
|
{
|
||||||
|
name: "URL-encoded slash",
|
||||||
|
email: "user%2Fpath@example.com",
|
||||||
|
expected: "user_path@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL-encoded backslash",
|
||||||
|
email: "user%5Cpath@example.com",
|
||||||
|
expected: "user_path@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL-encoded dot",
|
||||||
|
email: "%2E%2E%2Fetc%2Fpasswd",
|
||||||
|
expected: "___etc_passwd",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL-encoded null",
|
||||||
|
email: "user%00@evil.com",
|
||||||
|
expected: "user_@evil.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Double URL-encoding attack",
|
||||||
|
email: "%252F%252E%252E",
|
||||||
|
expected: "_252F_252E_252E", // % replaced with _, remaining chars preserved (safe)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mixed case URL-encoding",
|
||||||
|
email: "%2f%2F%5c%5C",
|
||||||
|
expected: "____",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SanitizeEmailForFilename(tt.email)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeEmailForFilename() = %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// createTestJWT creates a test JWT token with the given claims
|
||||||
|
func createTestJWT(claims map[string]any) string {
|
||||||
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"RS256","typ":"JWT"}`))
|
||||||
|
|
||||||
|
payloadBytes, _ := json.Marshal(claims)
|
||||||
|
payload := base64.RawURLEncoding.EncodeToString(payloadBytes)
|
||||||
|
|
||||||
|
signature := base64.RawURLEncoding.EncodeToString([]byte("fake-signature"))
|
||||||
|
|
||||||
|
return header + "." + payload + "." + signature
|
||||||
|
}
|
||||||
296
internal/auth/kiro/oauth.go
Normal file
296
internal/auth/kiro/oauth.go
Normal file
@@ -0,0 +1,296 @@
|
|||||||
|
// Package kiro provides OAuth2 authentication for Kiro using native Google login.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Kiro auth endpoint
|
||||||
|
kiroAuthEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
||||||
|
|
||||||
|
// Default callback port
|
||||||
|
defaultCallbackPort = 9876
|
||||||
|
|
||||||
|
// Auth timeout
|
||||||
|
authTimeout = 10 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// KiroTokenResponse represents the response from Kiro token endpoint.
|
||||||
|
type KiroTokenResponse struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ProfileArn string `json:"profileArn"`
|
||||||
|
ExpiresIn int `json:"expiresIn"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroOAuth handles the OAuth flow for Kiro authentication.
|
||||||
|
type KiroOAuth struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKiroOAuth creates a new Kiro OAuth handler.
|
||||||
|
func NewKiroOAuth(cfg *config.Config) *KiroOAuth {
|
||||||
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if cfg != nil {
|
||||||
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
|
}
|
||||||
|
return &KiroOAuth{
|
||||||
|
httpClient: client,
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCodeVerifier generates a random code verifier for PKCE.
|
||||||
|
func generateCodeVerifier() (string, error) {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCodeChallenge generates the code challenge from verifier.
|
||||||
|
func generateCodeChallenge(verifier string) string {
|
||||||
|
h := sha256.Sum256([]byte(verifier))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateState generates a random state parameter.
|
||||||
|
func generateState() (string, error) {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthResult contains the authorization code and state from callback.
|
||||||
|
type AuthResult struct {
|
||||||
|
Code string
|
||||||
|
State string
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
// startCallbackServer starts a local HTTP server to receive the OAuth callback.
|
||||||
|
func (o *KiroOAuth) startCallbackServer(ctx context.Context, expectedState string) (string, <-chan AuthResult, error) {
|
||||||
|
// Try to find an available port - use localhost like Kiro does
|
||||||
|
listener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", defaultCallbackPort))
|
||||||
|
if err != nil {
|
||||||
|
// Try with dynamic port (RFC 8252 allows dynamic ports for native apps)
|
||||||
|
log.Warnf("kiro oauth: default port %d is busy, falling back to dynamic port", defaultCallbackPort)
|
||||||
|
listener, err = net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
port := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
// Use http scheme for local callback server
|
||||||
|
redirectURI := fmt.Sprintf("http://localhost:%d/oauth/callback", port)
|
||||||
|
resultChan := make(chan AuthResult, 1)
|
||||||
|
|
||||||
|
server := &http.Server{
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
errParam := r.URL.Query().Get("error")
|
||||||
|
|
||||||
|
if errParam != "" {
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprintf(w, `<html><body><h1>Login Failed</h1><p>%s</p><p>You can close this window.</p></body></html>`, html.EscapeString(errParam))
|
||||||
|
resultChan <- AuthResult{Error: errParam}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if state != expectedState {
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprint(w, `<html><body><h1>Login Failed</h1><p>Invalid state parameter</p><p>You can close this window.</p></body></html>`)
|
||||||
|
resultChan <- AuthResult{Error: "state mismatch"}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
fmt.Fprint(w, `<html><body><h1>Login Successful!</h1><p>You can close this window and return to the terminal.</p></body></html>`)
|
||||||
|
resultChan <- AuthResult{Code: code, State: state}
|
||||||
|
})
|
||||||
|
|
||||||
|
server.Handler = mux
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Debugf("callback server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-time.After(authTimeout):
|
||||||
|
case <-resultChan:
|
||||||
|
}
|
||||||
|
_ = server.Shutdown(context.Background())
|
||||||
|
}()
|
||||||
|
|
||||||
|
return redirectURI, resultChan, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithBuilderID performs OAuth login with AWS Builder ID using device code flow.
|
||||||
|
func (o *KiroOAuth) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
ssoClient := NewSSOOIDCClient(o.cfg)
|
||||||
|
return ssoClient.LoginWithBuilderID(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// exchangeCodeForToken exchanges the authorization code for tokens.
|
||||||
|
func (o *KiroOAuth) exchangeCodeForToken(ctx context.Context, code, codeVerifier, redirectURI string) (*KiroTokenData, error) {
|
||||||
|
payload := map[string]string{
|
||||||
|
"code": code,
|
||||||
|
"code_verifier": codeVerifier,
|
||||||
|
"redirect_uri": redirectURI,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenURL := kiroAuthEndpoint + "/oauth/token"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
||||||
|
|
||||||
|
resp, err := o.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("token request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp KiroTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate ExpiresIn - use default 1 hour if invalid
|
||||||
|
expiresIn := tokenResp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: tokenResp.ProfileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "social",
|
||||||
|
Provider: "", // Caller should preserve original provider
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken refreshes an expired access token.
|
||||||
|
func (o *KiroOAuth) RefreshToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
|
||||||
|
payload := map[string]string{
|
||||||
|
"refreshToken": refreshToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshURL := kiroAuthEndpoint + "/refreshToken"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
||||||
|
|
||||||
|
resp, err := o.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp KiroTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate ExpiresIn - use default 1 hour if invalid
|
||||||
|
expiresIn := tokenResp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: tokenResp.ProfileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "social",
|
||||||
|
Provider: "", // Caller should preserve original provider
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithGoogle performs OAuth login with Google using Kiro's social auth.
|
||||||
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
|
func (o *KiroOAuth) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
socialClient := NewSocialAuthClient(o.cfg)
|
||||||
|
return socialClient.LoginWithGoogle(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithGitHub performs OAuth login with GitHub using Kiro's social auth.
|
||||||
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
|
func (o *KiroOAuth) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
socialClient := NewSocialAuthClient(o.cfg)
|
||||||
|
return socialClient.LoginWithGitHub(ctx)
|
||||||
|
}
|
||||||
725
internal/auth/kiro/protocol_handler.go
Normal file
725
internal/auth/kiro/protocol_handler.go
Normal file
@@ -0,0 +1,725 @@
|
|||||||
|
// Package kiro provides custom protocol handler registration for Kiro OAuth.
|
||||||
|
// This enables the CLI to intercept kiro:// URIs for social authentication (Google/GitHub).
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"html"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// KiroProtocol is the custom URI scheme used by Kiro
|
||||||
|
KiroProtocol = "kiro"
|
||||||
|
|
||||||
|
// KiroAuthority is the URI authority for authentication callbacks
|
||||||
|
KiroAuthority = "kiro.kiroAgent"
|
||||||
|
|
||||||
|
// KiroAuthPath is the path for successful authentication
|
||||||
|
KiroAuthPath = "/authenticate-success"
|
||||||
|
|
||||||
|
// KiroRedirectURI is the full redirect URI for social auth
|
||||||
|
KiroRedirectURI = "kiro://kiro.kiroAgent/authenticate-success"
|
||||||
|
|
||||||
|
// DefaultHandlerPort is the default port for the local callback server
|
||||||
|
DefaultHandlerPort = 19876
|
||||||
|
|
||||||
|
// HandlerTimeout is how long to wait for the OAuth callback
|
||||||
|
HandlerTimeout = 10 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProtocolHandler manages the custom kiro:// protocol handler for OAuth callbacks.
|
||||||
|
type ProtocolHandler struct {
|
||||||
|
port int
|
||||||
|
server *http.Server
|
||||||
|
listener net.Listener
|
||||||
|
resultChan chan *AuthCallback
|
||||||
|
stopChan chan struct{}
|
||||||
|
mu sync.Mutex
|
||||||
|
running bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthCallback contains the OAuth callback parameters.
|
||||||
|
type AuthCallback struct {
|
||||||
|
Code string
|
||||||
|
State string
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProtocolHandler creates a new protocol handler.
|
||||||
|
func NewProtocolHandler() *ProtocolHandler {
|
||||||
|
return &ProtocolHandler{
|
||||||
|
port: DefaultHandlerPort,
|
||||||
|
resultChan: make(chan *AuthCallback, 1),
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the local callback server that receives redirects from the protocol handler.
|
||||||
|
func (h *ProtocolHandler) Start(ctx context.Context) (int, error) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
if h.running {
|
||||||
|
return h.port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain any stale results from previous runs
|
||||||
|
select {
|
||||||
|
case <-h.resultChan:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset stopChan for reuse - close old channel first to unblock any waiting goroutines
|
||||||
|
if h.stopChan != nil {
|
||||||
|
select {
|
||||||
|
case <-h.stopChan:
|
||||||
|
// Already closed
|
||||||
|
default:
|
||||||
|
close(h.stopChan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.stopChan = make(chan struct{})
|
||||||
|
|
||||||
|
// Try ports in known range (must match handler script port range)
|
||||||
|
var listener net.Listener
|
||||||
|
var err error
|
||||||
|
portRange := []int{DefaultHandlerPort, DefaultHandlerPort + 1, DefaultHandlerPort + 2, DefaultHandlerPort + 3, DefaultHandlerPort + 4}
|
||||||
|
|
||||||
|
for _, port := range portRange {
|
||||||
|
listener, err = net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Debugf("kiro protocol handler: port %d busy, trying next", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if listener == nil {
|
||||||
|
return 0, fmt.Errorf("failed to start callback server: all ports %d-%d are busy", DefaultHandlerPort, DefaultHandlerPort+4)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.listener = listener
|
||||||
|
h.port = listener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/oauth/callback", h.handleCallback)
|
||||||
|
|
||||||
|
h.server = &http.Server{
|
||||||
|
Handler: mux,
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := h.server.Serve(listener); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Debugf("kiro protocol handler server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
h.running = true
|
||||||
|
log.Debugf("kiro protocol handler started on port %d", h.port)
|
||||||
|
|
||||||
|
// Auto-shutdown after context done, timeout, or explicit stop
|
||||||
|
// Capture references to prevent race with new Start() calls
|
||||||
|
currentStopChan := h.stopChan
|
||||||
|
currentServer := h.server
|
||||||
|
currentListener := h.listener
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-time.After(HandlerTimeout):
|
||||||
|
case <-currentStopChan:
|
||||||
|
return // Already stopped, exit goroutine
|
||||||
|
}
|
||||||
|
// Only stop if this is still the current server/listener instance
|
||||||
|
h.mu.Lock()
|
||||||
|
if h.server == currentServer && h.listener == currentListener {
|
||||||
|
h.mu.Unlock()
|
||||||
|
h.Stop()
|
||||||
|
} else {
|
||||||
|
h.mu.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return h.port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the callback server.
|
||||||
|
func (h *ProtocolHandler) Stop() {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
|
if !h.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signal the auto-shutdown goroutine to exit.
|
||||||
|
// This select pattern is safe because stopChan is only modified while holding h.mu,
|
||||||
|
// and we hold the lock here. The select prevents panic from double-close.
|
||||||
|
select {
|
||||||
|
case <-h.stopChan:
|
||||||
|
// Already closed
|
||||||
|
default:
|
||||||
|
close(h.stopChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.server != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = h.server.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.running = false
|
||||||
|
log.Debug("kiro protocol handler stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForCallback waits for the OAuth callback and returns the result.
|
||||||
|
func (h *ProtocolHandler) WaitForCallback(ctx context.Context) (*AuthCallback, error) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(HandlerTimeout):
|
||||||
|
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||||
|
case result := <-h.resultChan:
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPort returns the port the handler is listening on.
|
||||||
|
func (h *ProtocolHandler) GetPort() int {
|
||||||
|
return h.port
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleCallback processes the OAuth callback from the protocol handler script.
|
||||||
|
func (h *ProtocolHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
errParam := r.URL.Query().Get("error")
|
||||||
|
|
||||||
|
result := &AuthCallback{
|
||||||
|
Code: code,
|
||||||
|
State: state,
|
||||||
|
Error: errParam,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send result
|
||||||
|
select {
|
||||||
|
case h.resultChan <- result:
|
||||||
|
default:
|
||||||
|
// Channel full, ignore duplicate callbacks
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send success response
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
if errParam != "" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprintf(w, `<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head><title>Login Failed</title></head>
|
||||||
|
<body>
|
||||||
|
<h1>Login Failed</h1>
|
||||||
|
<p>Error: %s</p>
|
||||||
|
<p>You can close this window.</p>
|
||||||
|
</body>
|
||||||
|
</html>`, html.EscapeString(errParam))
|
||||||
|
} else {
|
||||||
|
fmt.Fprint(w, `<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head><title>Login Successful</title></head>
|
||||||
|
<body>
|
||||||
|
<h1>Login Successful!</h1>
|
||||||
|
<p>You can close this window and return to the terminal.</p>
|
||||||
|
<script>window.close();</script>
|
||||||
|
</body>
|
||||||
|
</html>`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsProtocolHandlerInstalled checks if the kiro:// protocol handler is installed.
|
||||||
|
func IsProtocolHandlerInstalled() bool {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
return isLinuxHandlerInstalled()
|
||||||
|
case "windows":
|
||||||
|
return isWindowsHandlerInstalled()
|
||||||
|
case "darwin":
|
||||||
|
return isDarwinHandlerInstalled()
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InstallProtocolHandler installs the kiro:// protocol handler for the current platform.
|
||||||
|
func InstallProtocolHandler(handlerPort int) error {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
return installLinuxHandler(handlerPort)
|
||||||
|
case "windows":
|
||||||
|
return installWindowsHandler(handlerPort)
|
||||||
|
case "darwin":
|
||||||
|
return installDarwinHandler(handlerPort)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UninstallProtocolHandler removes the kiro:// protocol handler.
|
||||||
|
func UninstallProtocolHandler() error {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
return uninstallLinuxHandler()
|
||||||
|
case "windows":
|
||||||
|
return uninstallWindowsHandler()
|
||||||
|
case "darwin":
|
||||||
|
return uninstallDarwinHandler()
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Linux Implementation ---
|
||||||
|
|
||||||
|
func getLinuxDesktopPath() string {
|
||||||
|
homeDir, _ := os.UserHomeDir()
|
||||||
|
return filepath.Join(homeDir, ".local", "share", "applications", "kiro-oauth-handler.desktop")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLinuxHandlerScriptPath() string {
|
||||||
|
homeDir, _ := os.UserHomeDir()
|
||||||
|
return filepath.Join(homeDir, ".local", "bin", "kiro-oauth-handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLinuxHandlerInstalled() bool {
|
||||||
|
desktopPath := getLinuxDesktopPath()
|
||||||
|
_, err := os.Stat(desktopPath)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func installLinuxHandler(handlerPort int) error {
|
||||||
|
// Create directories
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
binDir := filepath.Join(homeDir, ".local", "bin")
|
||||||
|
appDir := filepath.Join(homeDir, ".local", "share", "applications")
|
||||||
|
|
||||||
|
if err := os.MkdirAll(binDir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create bin directory: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(appDir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create applications directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create handler script - tries multiple ports to handle dynamic port allocation
|
||||||
|
scriptPath := getLinuxHandlerScriptPath()
|
||||||
|
scriptContent := fmt.Sprintf(`#!/bin/bash
|
||||||
|
# Kiro OAuth Protocol Handler
|
||||||
|
# Handles kiro:// URIs - tries CLI first, then forwards to Kiro IDE
|
||||||
|
|
||||||
|
URL="$1"
|
||||||
|
|
||||||
|
# Check curl availability
|
||||||
|
if ! command -v curl &> /dev/null; then
|
||||||
|
echo "Error: curl is required for Kiro OAuth handler" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Extract code and state from URL
|
||||||
|
[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
|
||||||
|
[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
|
||||||
|
[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"
|
||||||
|
|
||||||
|
# Try CLI proxy on multiple possible ports (default + dynamic range)
|
||||||
|
CLI_OK=0
|
||||||
|
for PORT in %d %d %d %d %d; do
|
||||||
|
if [ -n "$ERROR" ]; then
|
||||||
|
curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && CLI_OK=1 && break
|
||||||
|
elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
|
||||||
|
curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && CLI_OK=1 && break
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# If CLI not available, forward to Kiro IDE
|
||||||
|
if [ $CLI_OK -eq 0 ] && [ -x "/usr/share/kiro/kiro" ]; then
|
||||||
|
/usr/share/kiro/kiro --open-url "$URL" &
|
||||||
|
fi
|
||||||
|
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
|
||||||
|
|
||||||
|
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to write handler script: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create .desktop file
|
||||||
|
desktopPath := getLinuxDesktopPath()
|
||||||
|
desktopContent := fmt.Sprintf(`[Desktop Entry]
|
||||||
|
Name=Kiro OAuth Handler
|
||||||
|
Comment=Handle kiro:// protocol for CLI Proxy API authentication
|
||||||
|
Exec=%s %%u
|
||||||
|
Type=Application
|
||||||
|
Terminal=false
|
||||||
|
NoDisplay=true
|
||||||
|
MimeType=x-scheme-handler/kiro;
|
||||||
|
Categories=Utility;
|
||||||
|
`, scriptPath)
|
||||||
|
|
||||||
|
if err := os.WriteFile(desktopPath, []byte(desktopContent), 0644); err != nil {
|
||||||
|
return fmt.Errorf("failed to write desktop file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register handler with xdg-mime
|
||||||
|
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
log.Warnf("xdg-mime registration failed (may need manual setup): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update desktop database
|
||||||
|
cmd = exec.Command("update-desktop-database", appDir)
|
||||||
|
_ = cmd.Run() // Ignore errors, not critical
|
||||||
|
|
||||||
|
log.Info("Kiro protocol handler installed for Linux")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func uninstallLinuxHandler() error {
|
||||||
|
desktopPath := getLinuxDesktopPath()
|
||||||
|
scriptPath := getLinuxHandlerScriptPath()
|
||||||
|
|
||||||
|
if err := os.Remove(desktopPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("failed to remove desktop file: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.Remove(scriptPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("failed to remove handler script: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Kiro protocol handler uninstalled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Windows Implementation ---
|
||||||
|
|
||||||
|
func isWindowsHandlerInstalled() bool {
|
||||||
|
// Check registry key existence
|
||||||
|
cmd := exec.Command("reg", "query", `HKCU\Software\Classes\kiro`, "/ve")
|
||||||
|
return cmd.Run() == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func installWindowsHandler(handlerPort int) error {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create handler script (PowerShell)
|
||||||
|
scriptDir := filepath.Join(homeDir, ".cliproxyapi")
|
||||||
|
if err := os.MkdirAll(scriptDir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create script directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
scriptPath := filepath.Join(scriptDir, "kiro-oauth-handler.ps1")
|
||||||
|
scriptContent := fmt.Sprintf(`# Kiro OAuth Protocol Handler for Windows
|
||||||
|
param([string]$url)
|
||||||
|
|
||||||
|
# Load required assembly for HttpUtility
|
||||||
|
Add-Type -AssemblyName System.Web
|
||||||
|
|
||||||
|
# Parse URL parameters
|
||||||
|
$uri = [System.Uri]$url
|
||||||
|
$query = [System.Web.HttpUtility]::ParseQueryString($uri.Query)
|
||||||
|
$code = $query["code"]
|
||||||
|
$state = $query["state"]
|
||||||
|
$errorParam = $query["error"]
|
||||||
|
|
||||||
|
# Try multiple ports (default + dynamic range)
|
||||||
|
$ports = @(%d, %d, %d, %d, %d)
|
||||||
|
$success = $false
|
||||||
|
|
||||||
|
foreach ($port in $ports) {
|
||||||
|
if ($success) { break }
|
||||||
|
$callbackUrl = "http://127.0.0.1:$port/oauth/callback"
|
||||||
|
try {
|
||||||
|
if ($errorParam) {
|
||||||
|
$fullUrl = $callbackUrl + "?error=" + $errorParam
|
||||||
|
Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
|
||||||
|
$success = $true
|
||||||
|
} elseif ($code -and $state) {
|
||||||
|
$fullUrl = $callbackUrl + "?code=" + $code + "&state=" + $state
|
||||||
|
Invoke-WebRequest -Uri $fullUrl -UseBasicParsing -TimeoutSec 1 -ErrorAction Stop | Out-Null
|
||||||
|
$success = $true
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
# Try next port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
|
||||||
|
|
||||||
|
if err := os.WriteFile(scriptPath, []byte(scriptContent), 0644); err != nil {
|
||||||
|
return fmt.Errorf("failed to write handler script: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create batch wrapper
|
||||||
|
batchPath := filepath.Join(scriptDir, "kiro-oauth-handler.bat")
|
||||||
|
batchContent := fmt.Sprintf("@echo off\npowershell -ExecutionPolicy Bypass -File \"%s\" %%1\n", scriptPath)
|
||||||
|
|
||||||
|
if err := os.WriteFile(batchPath, []byte(batchContent), 0644); err != nil {
|
||||||
|
return fmt.Errorf("failed to write batch wrapper: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register in Windows registry
|
||||||
|
commands := [][]string{
|
||||||
|
{"reg", "add", `HKCU\Software\Classes\kiro`, "/ve", "/d", "URL:Kiro Protocol", "/f"},
|
||||||
|
{"reg", "add", `HKCU\Software\Classes\kiro`, "/v", "URL Protocol", "/d", "", "/f"},
|
||||||
|
{"reg", "add", `HKCU\Software\Classes\kiro\shell`, "/f"},
|
||||||
|
{"reg", "add", `HKCU\Software\Classes\kiro\shell\open`, "/f"},
|
||||||
|
{"reg", "add", `HKCU\Software\Classes\kiro\shell\open\command`, "/ve", "/d", fmt.Sprintf("\"%s\" \"%%1\"", batchPath), "/f"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, args := range commands {
|
||||||
|
cmd := exec.Command(args[0], args[1:]...)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
return fmt.Errorf("failed to run registry command: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Kiro protocol handler installed for Windows")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func uninstallWindowsHandler() error {
|
||||||
|
// Remove registry keys
|
||||||
|
cmd := exec.Command("reg", "delete", `HKCU\Software\Classes\kiro`, "/f")
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
log.Warnf("failed to remove registry key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove scripts
|
||||||
|
homeDir, _ := os.UserHomeDir()
|
||||||
|
scriptDir := filepath.Join(homeDir, ".cliproxyapi")
|
||||||
|
_ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.ps1"))
|
||||||
|
_ = os.Remove(filepath.Join(scriptDir, "kiro-oauth-handler.bat"))
|
||||||
|
|
||||||
|
log.Info("Kiro protocol handler uninstalled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- macOS Implementation ---
|
||||||
|
|
||||||
|
func getDarwinAppPath() string {
|
||||||
|
homeDir, _ := os.UserHomeDir()
|
||||||
|
return filepath.Join(homeDir, "Applications", "KiroOAuthHandler.app")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDarwinHandlerInstalled() bool {
|
||||||
|
appPath := getDarwinAppPath()
|
||||||
|
_, err := os.Stat(appPath)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func installDarwinHandler(handlerPort int) error {
|
||||||
|
// Create app bundle structure
|
||||||
|
appPath := getDarwinAppPath()
|
||||||
|
contentsPath := filepath.Join(appPath, "Contents")
|
||||||
|
macOSPath := filepath.Join(contentsPath, "MacOS")
|
||||||
|
|
||||||
|
if err := os.MkdirAll(macOSPath, 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create app bundle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create Info.plist
|
||||||
|
plistPath := filepath.Join(contentsPath, "Info.plist")
|
||||||
|
plistContent := `<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||||
|
<plist version="1.0">
|
||||||
|
<dict>
|
||||||
|
<key>CFBundleIdentifier</key>
|
||||||
|
<string>com.cliproxyapi.kiro-oauth-handler</string>
|
||||||
|
<key>CFBundleName</key>
|
||||||
|
<string>KiroOAuthHandler</string>
|
||||||
|
<key>CFBundleExecutable</key>
|
||||||
|
<string>kiro-oauth-handler</string>
|
||||||
|
<key>CFBundleVersion</key>
|
||||||
|
<string>1.0</string>
|
||||||
|
<key>CFBundleURLTypes</key>
|
||||||
|
<array>
|
||||||
|
<dict>
|
||||||
|
<key>CFBundleURLName</key>
|
||||||
|
<string>Kiro Protocol</string>
|
||||||
|
<key>CFBundleURLSchemes</key>
|
||||||
|
<array>
|
||||||
|
<string>kiro</string>
|
||||||
|
</array>
|
||||||
|
</dict>
|
||||||
|
</array>
|
||||||
|
<key>LSBackgroundOnly</key>
|
||||||
|
<true/>
|
||||||
|
</dict>
|
||||||
|
</plist>`
|
||||||
|
|
||||||
|
if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil {
|
||||||
|
return fmt.Errorf("failed to write Info.plist: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create executable script - tries multiple ports to handle dynamic port allocation
|
||||||
|
execPath := filepath.Join(macOSPath, "kiro-oauth-handler")
|
||||||
|
execContent := fmt.Sprintf(`#!/bin/bash
|
||||||
|
# Kiro OAuth Protocol Handler for macOS
|
||||||
|
|
||||||
|
URL="$1"
|
||||||
|
|
||||||
|
# Check curl availability (should always exist on macOS)
|
||||||
|
if [ ! -x /usr/bin/curl ]; then
|
||||||
|
echo "Error: curl is required for Kiro OAuth handler" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Extract code and state from URL
|
||||||
|
[[ "$URL" =~ code=([^&]+) ]] && CODE="${BASH_REMATCH[1]}"
|
||||||
|
[[ "$URL" =~ state=([^&]+) ]] && STATE="${BASH_REMATCH[1]}"
|
||||||
|
[[ "$URL" =~ error=([^&]+) ]] && ERROR="${BASH_REMATCH[1]}"
|
||||||
|
|
||||||
|
# Try multiple ports (default + dynamic range)
|
||||||
|
for PORT in %d %d %d %d %d; do
|
||||||
|
if [ -n "$ERROR" ]; then
|
||||||
|
/usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?error=$ERROR" && exit 0
|
||||||
|
elif [ -n "$CODE" ] && [ -n "$STATE" ]; then
|
||||||
|
/usr/bin/curl -sf --connect-timeout 1 "http://127.0.0.1:$PORT/oauth/callback?code=$CODE&state=$STATE" && exit 0
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
`, handlerPort, handlerPort+1, handlerPort+2, handlerPort+3, handlerPort+4)
|
||||||
|
|
||||||
|
if err := os.WriteFile(execPath, []byte(execContent), 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to write executable: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register the app with Launch Services
|
||||||
|
cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
|
||||||
|
"-f", appPath)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
log.Warnf("lsregister failed (handler may still work): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Kiro protocol handler installed for macOS")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func uninstallDarwinHandler() error {
|
||||||
|
appPath := getDarwinAppPath()
|
||||||
|
|
||||||
|
// Unregister from Launch Services
|
||||||
|
cmd := exec.Command("/System/Library/Frameworks/CoreServices.framework/Frameworks/LaunchServices.framework/Support/lsregister",
|
||||||
|
"-u", appPath)
|
||||||
|
_ = cmd.Run()
|
||||||
|
|
||||||
|
// Remove app bundle
|
||||||
|
if err := os.RemoveAll(appPath); err != nil && !os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("failed to remove app bundle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Kiro protocol handler uninstalled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseKiroURI parses a kiro:// URI and extracts the callback parameters.
|
||||||
|
func ParseKiroURI(rawURI string) (*AuthCallback, error) {
|
||||||
|
u, err := url.Parse(rawURI)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid URI: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Scheme != KiroProtocol {
|
||||||
|
return nil, fmt.Errorf("invalid scheme: expected %s, got %s", KiroProtocol, u.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.Host != KiroAuthority {
|
||||||
|
return nil, fmt.Errorf("invalid authority: expected %s, got %s", KiroAuthority, u.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := u.Query()
|
||||||
|
return &AuthCallback{
|
||||||
|
Code: query.Get("code"),
|
||||||
|
State: query.Get("state"),
|
||||||
|
Error: query.Get("error"),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHandlerInstructions returns platform-specific instructions for manual handler setup.
|
||||||
|
func GetHandlerInstructions() string {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
return `To manually set up the Kiro protocol handler on Linux:
|
||||||
|
|
||||||
|
1. Create ~/.local/share/applications/kiro-oauth-handler.desktop:
|
||||||
|
[Desktop Entry]
|
||||||
|
Name=Kiro OAuth Handler
|
||||||
|
Exec=~/.local/bin/kiro-oauth-handler %u
|
||||||
|
Type=Application
|
||||||
|
Terminal=false
|
||||||
|
MimeType=x-scheme-handler/kiro;
|
||||||
|
|
||||||
|
2. Create ~/.local/bin/kiro-oauth-handler (make it executable):
|
||||||
|
#!/bin/bash
|
||||||
|
URL="$1"
|
||||||
|
# ... (see generated script for full content)
|
||||||
|
|
||||||
|
3. Run: xdg-mime default kiro-oauth-handler.desktop x-scheme-handler/kiro`
|
||||||
|
|
||||||
|
case "windows":
|
||||||
|
return `To manually set up the Kiro protocol handler on Windows:
|
||||||
|
|
||||||
|
1. Open Registry Editor (regedit.exe)
|
||||||
|
2. Create key: HKEY_CURRENT_USER\Software\Classes\kiro
|
||||||
|
3. Set default value to: URL:Kiro Protocol
|
||||||
|
4. Create string value "URL Protocol" with empty data
|
||||||
|
5. Create subkey: shell\open\command
|
||||||
|
6. Set default value to: "C:\path\to\handler.bat" "%1"`
|
||||||
|
|
||||||
|
case "darwin":
|
||||||
|
return `To manually set up the Kiro protocol handler on macOS:
|
||||||
|
|
||||||
|
1. Create ~/Applications/KiroOAuthHandler.app bundle
|
||||||
|
2. Add Info.plist with CFBundleURLTypes containing "kiro" scheme
|
||||||
|
3. Create executable in Contents/MacOS/
|
||||||
|
4. Run: /System/Library/.../lsregister -f ~/Applications/KiroOAuthHandler.app`
|
||||||
|
|
||||||
|
default:
|
||||||
|
return "Protocol handler setup is not supported on this platform."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupProtocolHandlerIfNeeded checks and installs the protocol handler if needed.
|
||||||
|
func SetupProtocolHandlerIfNeeded(handlerPort int) error {
|
||||||
|
if IsProtocolHandlerInstalled() {
|
||||||
|
log.Debug("Kiro protocol handler already installed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||||
|
fmt.Println("║ Kiro Protocol Handler Setup Required ║")
|
||||||
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||||
|
fmt.Println("\nTo enable Google/GitHub login, we need to install a protocol handler.")
|
||||||
|
fmt.Println("This allows your browser to redirect back to the CLI after authentication.")
|
||||||
|
fmt.Println("\nInstalling protocol handler...")
|
||||||
|
|
||||||
|
if err := InstallProtocolHandler(handlerPort); err != nil {
|
||||||
|
fmt.Printf("\n⚠ Automatic installation failed: %v\n", err)
|
||||||
|
fmt.Println("\nManual setup instructions:")
|
||||||
|
fmt.Println(strings.Repeat("-", 60))
|
||||||
|
fmt.Println(GetHandlerInstructions())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n✓ Protocol handler installed successfully!")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
403
internal/auth/kiro/social_auth.go
Normal file
403
internal/auth/kiro/social_auth.go
Normal file
@@ -0,0 +1,403 @@
|
|||||||
|
// Package kiro provides social authentication (Google/GitHub) for Kiro via AuthServiceClient.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Kiro AuthService endpoint
|
||||||
|
kiroAuthServiceEndpoint = "https://prod.us-east-1.auth.desktop.kiro.dev"
|
||||||
|
|
||||||
|
// OAuth timeout
|
||||||
|
socialAuthTimeout = 10 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// SocialProvider represents the social login provider.
|
||||||
|
type SocialProvider string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ProviderGoogle is Google OAuth provider
|
||||||
|
ProviderGoogle SocialProvider = "Google"
|
||||||
|
// ProviderGitHub is GitHub OAuth provider
|
||||||
|
ProviderGitHub SocialProvider = "Github"
|
||||||
|
// Note: AWS Builder ID is NOT supported by Kiro's auth service.
|
||||||
|
// It only supports: Google, Github, Cognito
|
||||||
|
// AWS Builder ID must use device code flow via SSO OIDC.
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateTokenRequest is sent to Kiro's /oauth/token endpoint.
|
||||||
|
type CreateTokenRequest struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
CodeVerifier string `json:"code_verifier"`
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
InvitationCode string `json:"invitation_code,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SocialTokenResponse from Kiro's /oauth/token endpoint for social auth.
|
||||||
|
type SocialTokenResponse struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
ProfileArn string `json:"profileArn"`
|
||||||
|
ExpiresIn int `json:"expiresIn"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenRequest is sent to Kiro's /refreshToken endpoint.
|
||||||
|
type RefreshTokenRequest struct {
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SocialAuthClient handles social authentication with Kiro.
|
||||||
|
type SocialAuthClient struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
cfg *config.Config
|
||||||
|
protocolHandler *ProtocolHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSocialAuthClient creates a new social auth client.
|
||||||
|
func NewSocialAuthClient(cfg *config.Config) *SocialAuthClient {
|
||||||
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if cfg != nil {
|
||||||
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
|
}
|
||||||
|
return &SocialAuthClient{
|
||||||
|
httpClient: client,
|
||||||
|
cfg: cfg,
|
||||||
|
protocolHandler: NewProtocolHandler(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generatePKCE generates PKCE code verifier and challenge.
|
||||||
|
func generatePKCE() (verifier, challenge string, err error) {
|
||||||
|
// Generate 32 bytes of random data for verifier
|
||||||
|
b := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||||
|
}
|
||||||
|
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
|
||||||
|
// Generate SHA256 hash of verifier for challenge
|
||||||
|
h := sha256.Sum256([]byte(verifier))
|
||||||
|
challenge = base64.RawURLEncoding.EncodeToString(h[:])
|
||||||
|
|
||||||
|
return verifier, challenge, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateState generates a random state parameter.
|
||||||
|
func generateStateParam() (string, error) {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildLoginURL constructs the Kiro OAuth login URL.
|
||||||
|
// The login endpoint expects a GET request with query parameters.
|
||||||
|
// Format: /login?idp=Google&redirect_uri=...&code_challenge=...&code_challenge_method=S256&state=...&prompt=select_account
|
||||||
|
// The prompt=select_account parameter forces the account selection screen even if already logged in.
|
||||||
|
func (c *SocialAuthClient) buildLoginURL(provider, redirectURI, codeChallenge, state string) string {
|
||||||
|
return fmt.Sprintf("%s/login?idp=%s&redirect_uri=%s&code_challenge=%s&code_challenge_method=S256&state=%s&prompt=select_account",
|
||||||
|
kiroAuthServiceEndpoint,
|
||||||
|
provider,
|
||||||
|
url.QueryEscape(redirectURI),
|
||||||
|
codeChallenge,
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateToken exchanges the authorization code for tokens.
|
||||||
|
func (c *SocialAuthClient) CreateToken(ctx context.Context, req *CreateTokenRequest) (*SocialTokenResponse, error) {
|
||||||
|
body, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal token request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenURL := kiroAuthServiceEndpoint + "/oauth/token"
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("token request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read token response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("token exchange failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("token exchange failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp SocialTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokenResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshSocialToken refreshes an expired social auth token.
|
||||||
|
func (c *SocialAuthClient) RefreshSocialToken(ctx context.Context, refreshToken string) (*KiroTokenData, error) {
|
||||||
|
body, err := json.Marshal(&RefreshTokenRequest{RefreshToken: refreshToken})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshURL := kiroAuthServiceEndpoint + "/refreshToken"
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, refreshURL, strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
httpReq.Header.Set("User-Agent", "cli-proxy-api/1.0.0")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp SocialTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &tokenResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate ExpiresIn - use default 1 hour if invalid
|
||||||
|
expiresIn := tokenResp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600 // Default 1 hour
|
||||||
|
}
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: tokenResp.ProfileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "social",
|
||||||
|
Provider: "", // Caller should preserve original provider
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithSocial performs OAuth login with Google.
|
||||||
|
func (c *SocialAuthClient) LoginWithSocial(ctx context.Context, provider SocialProvider) (*KiroTokenData, error) {
|
||||||
|
providerName := string(provider)
|
||||||
|
|
||||||
|
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||||
|
fmt.Printf("║ Kiro Authentication (%s) ║\n", providerName)
|
||||||
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||||
|
|
||||||
|
// Step 1: Setup protocol handler
|
||||||
|
fmt.Println("\nSetting up authentication...")
|
||||||
|
|
||||||
|
// Start the local callback server
|
||||||
|
handlerPort, err := c.protocolHandler.Start(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to start callback server: %w", err)
|
||||||
|
}
|
||||||
|
defer c.protocolHandler.Stop()
|
||||||
|
|
||||||
|
// Ensure protocol handler is installed and set as default
|
||||||
|
if err := SetupProtocolHandlerIfNeeded(handlerPort); err != nil {
|
||||||
|
fmt.Println("\n⚠ Protocol handler setup failed. Trying alternative method...")
|
||||||
|
fmt.Println(" If you see a browser 'Open with' dialog, select your default browser.")
|
||||||
|
fmt.Println(" For manual setup instructions, run: cliproxy kiro --help-protocol")
|
||||||
|
log.Debugf("kiro: protocol handler setup error: %v", err)
|
||||||
|
// Continue anyway - user might have set it up manually or select browser manually
|
||||||
|
} else {
|
||||||
|
// Force set our handler as default (prevents "Open with" dialog)
|
||||||
|
forceDefaultProtocolHandler()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Generate PKCE codes
|
||||||
|
codeVerifier, codeChallenge, err := generatePKCE()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate PKCE: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Generate state
|
||||||
|
state, err := generateStateParam()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Build the login URL (Kiro uses GET request with query params)
|
||||||
|
authURL := c.buildLoginURL(providerName, KiroRedirectURI, codeChallenge, state)
|
||||||
|
|
||||||
|
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
|
||||||
|
// Incognito mode enables multi-account support by bypassing cached sessions
|
||||||
|
if c.cfg != nil {
|
||||||
|
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
|
||||||
|
if !c.cfg.IncognitoBrowser {
|
||||||
|
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
|
||||||
|
} else {
|
||||||
|
log.Debug("kiro: using incognito mode for multi-account support")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
browser.SetIncognitoMode(true) // Default to incognito if no config
|
||||||
|
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Open browser for user authentication
|
||||||
|
fmt.Println("\n════════════════════════════════════════════════════════════")
|
||||||
|
fmt.Printf(" Opening browser for %s authentication...\n", providerName)
|
||||||
|
fmt.Println("════════════════════════════════════════════════════════════")
|
||||||
|
fmt.Printf("\n URL: %s\n\n", authURL)
|
||||||
|
|
||||||
|
if err := browser.OpenURL(authURL); err != nil {
|
||||||
|
log.Warnf("Could not open browser automatically: %v", err)
|
||||||
|
fmt.Println(" ⚠ Could not open browser automatically.")
|
||||||
|
fmt.Println(" Please open the URL above in your browser manually.")
|
||||||
|
} else {
|
||||||
|
fmt.Println(" (Browser opened automatically)")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n Waiting for authentication callback...")
|
||||||
|
|
||||||
|
// Step 6: Wait for callback
|
||||||
|
callback, err := c.protocolHandler.WaitForCallback(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to receive callback: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if callback.Error != "" {
|
||||||
|
return nil, fmt.Errorf("authentication error: %s", callback.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if callback.State != state {
|
||||||
|
// Log state values for debugging, but don't expose in user-facing error
|
||||||
|
log.Debugf("kiro: OAuth state mismatch - expected %s, got %s", state, callback.State)
|
||||||
|
return nil, fmt.Errorf("OAuth state validation failed - please try again")
|
||||||
|
}
|
||||||
|
|
||||||
|
if callback.Code == "" {
|
||||||
|
return nil, fmt.Errorf("no authorization code received")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n✓ Authorization received!")
|
||||||
|
|
||||||
|
// Step 7: Exchange code for tokens
|
||||||
|
fmt.Println("Exchanging code for tokens...")
|
||||||
|
|
||||||
|
tokenReq := &CreateTokenRequest{
|
||||||
|
Code: callback.Code,
|
||||||
|
CodeVerifier: codeVerifier,
|
||||||
|
RedirectURI: KiroRedirectURI,
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResp, err := c.CreateToken(ctx, tokenReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to exchange code for tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n✓ Authentication successful!")
|
||||||
|
|
||||||
|
// Close the browser window
|
||||||
|
if err := browser.CloseBrowser(); err != nil {
|
||||||
|
log.Debugf("Failed to close browser: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate ExpiresIn - use default 1 hour if invalid
|
||||||
|
expiresIn := tokenResp.ExpiresIn
|
||||||
|
if expiresIn <= 0 {
|
||||||
|
expiresIn = 3600
|
||||||
|
}
|
||||||
|
expiresAt := time.Now().Add(time.Duration(expiresIn) * time.Second)
|
||||||
|
|
||||||
|
// Try to extract email from JWT access token first
|
||||||
|
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||||
|
|
||||||
|
// If no email in JWT, ask user for account label (only in interactive mode)
|
||||||
|
if email == "" && isInteractiveTerminal() {
|
||||||
|
fmt.Print("\n Enter account label for file naming (optional, press Enter to skip): ")
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
var err error
|
||||||
|
email, err = reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Failed to read account label: %v", err)
|
||||||
|
}
|
||||||
|
email = strings.TrimSpace(email)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: tokenResp.ProfileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "social",
|
||||||
|
Provider: providerName,
|
||||||
|
Email: email, // JWT email or user-provided label
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithGoogle performs OAuth login with Google.
|
||||||
|
func (c *SocialAuthClient) LoginWithGoogle(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
return c.LoginWithSocial(ctx, ProviderGoogle)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithGitHub performs OAuth login with GitHub.
|
||||||
|
func (c *SocialAuthClient) LoginWithGitHub(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
return c.LoginWithSocial(ctx, ProviderGitHub)
|
||||||
|
}
|
||||||
|
|
||||||
|
// forceDefaultProtocolHandler sets our protocol handler as the default for kiro:// URLs.
|
||||||
|
// This prevents the "Open with" dialog from appearing on Linux.
|
||||||
|
// On non-Linux platforms, this is a no-op as they use different mechanisms.
|
||||||
|
func forceDefaultProtocolHandler() {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return // Non-Linux platforms use different handler mechanisms
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set our handler as default using xdg-mime
|
||||||
|
cmd := exec.Command("xdg-mime", "default", "kiro-oauth-handler.desktop", "x-scheme-handler/kiro")
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
log.Warnf("Failed to set default protocol handler: %v. You may see a handler selection dialog.", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isInteractiveTerminal checks if stdin is connected to an interactive terminal.
|
||||||
|
// Returns false in CI/automated environments or when stdin is piped.
|
||||||
|
func isInteractiveTerminal() bool {
|
||||||
|
return term.IsTerminal(int(os.Stdin.Fd()))
|
||||||
|
}
|
||||||
527
internal/auth/kiro/sso_oidc.go
Normal file
527
internal/auth/kiro/sso_oidc.go
Normal file
@@ -0,0 +1,527 @@
|
|||||||
|
// Package kiro provides AWS SSO OIDC authentication for Kiro.
|
||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// AWS SSO OIDC endpoints
|
||||||
|
ssoOIDCEndpoint = "https://oidc.us-east-1.amazonaws.com"
|
||||||
|
|
||||||
|
// Kiro's start URL for Builder ID
|
||||||
|
builderIDStartURL = "https://view.awsapps.com/start"
|
||||||
|
|
||||||
|
// Polling interval
|
||||||
|
pollInterval = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// SSOOIDCClient handles AWS SSO OIDC authentication.
|
||||||
|
type SSOOIDCClient struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSSOOIDCClient creates a new SSO OIDC client.
|
||||||
|
func NewSSOOIDCClient(cfg *config.Config) *SSOOIDCClient {
|
||||||
|
client := &http.Client{Timeout: 30 * time.Second}
|
||||||
|
if cfg != nil {
|
||||||
|
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||||
|
}
|
||||||
|
return &SSOOIDCClient{
|
||||||
|
httpClient: client,
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterClientResponse from AWS SSO OIDC.
|
||||||
|
type RegisterClientResponse struct {
|
||||||
|
ClientID string `json:"clientId"`
|
||||||
|
ClientSecret string `json:"clientSecret"`
|
||||||
|
ClientIDIssuedAt int64 `json:"clientIdIssuedAt"`
|
||||||
|
ClientSecretExpiresAt int64 `json:"clientSecretExpiresAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartDeviceAuthResponse from AWS SSO OIDC.
|
||||||
|
type StartDeviceAuthResponse struct {
|
||||||
|
DeviceCode string `json:"deviceCode"`
|
||||||
|
UserCode string `json:"userCode"`
|
||||||
|
VerificationURI string `json:"verificationUri"`
|
||||||
|
VerificationURIComplete string `json:"verificationUriComplete"`
|
||||||
|
ExpiresIn int `json:"expiresIn"`
|
||||||
|
Interval int `json:"interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTokenResponse from AWS SSO OIDC.
|
||||||
|
type CreateTokenResponse struct {
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
TokenType string `json:"tokenType"`
|
||||||
|
ExpiresIn int `json:"expiresIn"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterClient registers a new OIDC client with AWS.
|
||||||
|
func (c *SSOOIDCClient) RegisterClient(ctx context.Context) (*RegisterClientResponse, error) {
|
||||||
|
// Generate unique client name for each registration to support multiple accounts
|
||||||
|
clientName := fmt.Sprintf("CLI-Proxy-API-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"clientName": clientName,
|
||||||
|
"clientType": "public",
|
||||||
|
"scopes": []string{"codewhisperer:completions", "codewhisperer:analysis", "codewhisperer:conversations"},
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/client/register", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("register client failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("register client failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result RegisterClientResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartDeviceAuthorization starts the device authorization flow.
|
||||||
|
func (c *SSOOIDCClient) StartDeviceAuthorization(ctx context.Context, clientID, clientSecret string) (*StartDeviceAuthResponse, error) {
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"startUrl": builderIDStartURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/device_authorization", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("start device auth failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("start device auth failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result StartDeviceAuthResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateToken polls for the access token after user authorization.
|
||||||
|
func (c *SSOOIDCClient) CreateToken(ctx context.Context, clientID, clientSecret, deviceCode string) (*CreateTokenResponse, error) {
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"deviceCode": deviceCode,
|
||||||
|
"grantType": "urn:ietf:params:oauth:grant-type:device_code",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for pending authorization
|
||||||
|
if resp.StatusCode == http.StatusBadRequest {
|
||||||
|
var errResp struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
if json.Unmarshal(respBody, &errResp) == nil {
|
||||||
|
if errResp.Error == "authorization_pending" {
|
||||||
|
return nil, fmt.Errorf("authorization_pending")
|
||||||
|
}
|
||||||
|
if errResp.Error == "slow_down" {
|
||||||
|
return nil, fmt.Errorf("slow_down")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("create token failed: %s", string(respBody))
|
||||||
|
return nil, fmt.Errorf("create token failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("create token failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("create token failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result CreateTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken refreshes an access token using the refresh token.
|
||||||
|
func (c *SSOOIDCClient) RefreshToken(ctx context.Context, clientID, clientSecret, refreshToken string) (*KiroTokenData, error) {
|
||||||
|
payload := map[string]string{
|
||||||
|
"clientId": clientID,
|
||||||
|
"clientSecret": clientSecret,
|
||||||
|
"refreshToken": refreshToken,
|
||||||
|
"grantType": "refresh_token",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, ssoOIDCEndpoint+"/token", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("token refresh failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return nil, fmt.Errorf("token refresh failed (status %d)", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result CreateTokenResponse
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(result.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: result.AccessToken,
|
||||||
|
RefreshToken: result.RefreshToken,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "builder-id",
|
||||||
|
Provider: "AWS",
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithBuilderID performs the full device code flow for AWS Builder ID.
|
||||||
|
func (c *SSOOIDCClient) LoginWithBuilderID(ctx context.Context) (*KiroTokenData, error) {
|
||||||
|
fmt.Println("\n╔══════════════════════════════════════════════════════════╗")
|
||||||
|
fmt.Println("║ Kiro Authentication (AWS Builder ID) ║")
|
||||||
|
fmt.Println("╚══════════════════════════════════════════════════════════╝")
|
||||||
|
|
||||||
|
// Step 1: Register client
|
||||||
|
fmt.Println("\nRegistering client...")
|
||||||
|
regResp, err := c.RegisterClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||||
|
}
|
||||||
|
log.Debugf("Client registered: %s", regResp.ClientID)
|
||||||
|
|
||||||
|
// Step 2: Start device authorization
|
||||||
|
fmt.Println("Starting device authorization...")
|
||||||
|
authResp, err := c.StartDeviceAuthorization(ctx, regResp.ClientID, regResp.ClientSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to start device auth: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Show user the verification URL
|
||||||
|
fmt.Printf("\n")
|
||||||
|
fmt.Println("════════════════════════════════════════════════════════════")
|
||||||
|
fmt.Printf(" Open this URL in your browser:\n")
|
||||||
|
fmt.Printf(" %s\n", authResp.VerificationURIComplete)
|
||||||
|
fmt.Println("════════════════════════════════════════════════════════════")
|
||||||
|
fmt.Printf("\n Or go to: %s\n", authResp.VerificationURI)
|
||||||
|
fmt.Printf(" And enter code: %s\n\n", authResp.UserCode)
|
||||||
|
|
||||||
|
// Set incognito mode based on config (defaults to true for Kiro, can be overridden with --no-incognito)
|
||||||
|
// Incognito mode enables multi-account support by bypassing cached sessions
|
||||||
|
if c.cfg != nil {
|
||||||
|
browser.SetIncognitoMode(c.cfg.IncognitoBrowser)
|
||||||
|
if !c.cfg.IncognitoBrowser {
|
||||||
|
log.Info("kiro: using normal browser mode (--no-incognito). Note: You may not be able to select a different account.")
|
||||||
|
} else {
|
||||||
|
log.Debug("kiro: using incognito mode for multi-account support")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
browser.SetIncognitoMode(true) // Default to incognito if no config
|
||||||
|
log.Debug("kiro: using incognito mode for multi-account support (default)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open browser using cross-platform browser package
|
||||||
|
if err := browser.OpenURL(authResp.VerificationURIComplete); err != nil {
|
||||||
|
log.Warnf("Could not open browser automatically: %v", err)
|
||||||
|
fmt.Println(" Please open the URL manually in your browser.")
|
||||||
|
} else {
|
||||||
|
fmt.Println(" (Browser opened automatically)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Poll for token
|
||||||
|
fmt.Println("Waiting for authorization...")
|
||||||
|
|
||||||
|
interval := pollInterval
|
||||||
|
if authResp.Interval > 0 {
|
||||||
|
interval = time.Duration(authResp.Interval) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
deadline := time.Now().Add(time.Duration(authResp.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
browser.CloseBrowser() // Cleanup on cancel
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(interval):
|
||||||
|
tokenResp, err := c.CreateToken(ctx, regResp.ClientID, regResp.ClientSecret, authResp.DeviceCode)
|
||||||
|
if err != nil {
|
||||||
|
errStr := err.Error()
|
||||||
|
if strings.Contains(errStr, "authorization_pending") {
|
||||||
|
fmt.Print(".")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.Contains(errStr, "slow_down") {
|
||||||
|
interval += 5 * time.Second
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Close browser on error before returning
|
||||||
|
browser.CloseBrowser()
|
||||||
|
return nil, fmt.Errorf("token creation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n\n✓ Authorization successful!")
|
||||||
|
|
||||||
|
// Close the browser window
|
||||||
|
if err := browser.CloseBrowser(); err != nil {
|
||||||
|
log.Debugf("Failed to close browser: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Get profile ARN from CodeWhisperer API
|
||||||
|
fmt.Println("Fetching profile information...")
|
||||||
|
profileArn := c.fetchProfileArn(ctx, tokenResp.AccessToken)
|
||||||
|
|
||||||
|
// Extract email from JWT access token
|
||||||
|
email := ExtractEmailFromJWT(tokenResp.AccessToken)
|
||||||
|
if email != "" {
|
||||||
|
fmt.Printf(" Logged in as: %s\n", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||||
|
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ProfileArn: profileArn,
|
||||||
|
ExpiresAt: expiresAt.Format(time.RFC3339),
|
||||||
|
AuthMethod: "builder-id",
|
||||||
|
Provider: "AWS",
|
||||||
|
ClientID: regResp.ClientID,
|
||||||
|
ClientSecret: regResp.ClientSecret,
|
||||||
|
Email: email,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close browser on timeout for better UX
|
||||||
|
if err := browser.CloseBrowser(); err != nil {
|
||||||
|
log.Debugf("Failed to close browser on timeout: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("authorization timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchProfileArn retrieves the profile ARN from CodeWhisperer API.
|
||||||
|
// This is needed for file naming since AWS SSO OIDC doesn't return profile info.
|
||||||
|
func (c *SSOOIDCClient) fetchProfileArn(ctx context.Context, accessToken string) string {
|
||||||
|
// Try ListProfiles API first
|
||||||
|
profileArn := c.tryListProfiles(ctx, accessToken)
|
||||||
|
if profileArn != "" {
|
||||||
|
return profileArn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: Try ListAvailableCustomizations
|
||||||
|
return c.tryListCustomizations(ctx, accessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SSOOIDCClient) tryListProfiles(ctx context.Context, accessToken string) string {
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"origin": "AI_EDITOR",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||||
|
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListProfiles")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("ListProfiles failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("ListProfiles response: %s", string(respBody))
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Profiles []struct {
|
||||||
|
Arn string `json:"arn"`
|
||||||
|
} `json:"profiles"`
|
||||||
|
ProfileArn string `json:"profileArn"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ProfileArn != "" {
|
||||||
|
return result.ProfileArn
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Profiles) > 0 {
|
||||||
|
return result.Profiles[0].Arn
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SSOOIDCClient) tryListCustomizations(ctx context.Context, accessToken string) string {
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"origin": "AI_EDITOR",
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://codewhisperer.us-east-1.amazonaws.com", strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-amz-json-1.0")
|
||||||
|
req.Header.Set("x-amz-target", "AmazonCodeWhispererService.ListAvailableCustomizations")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
log.Debugf("ListAvailableCustomizations failed (status %d): %s", resp.StatusCode, string(respBody))
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("ListAvailableCustomizations response: %s", string(respBody))
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
Customizations []struct {
|
||||||
|
Arn string `json:"arn"`
|
||||||
|
} `json:"customizations"`
|
||||||
|
ProfileArn string `json:"profileArn"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ProfileArn != "" {
|
||||||
|
return result.ProfileArn
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Customizations) > 0 {
|
||||||
|
return result.Customizations[0].Arn
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
72
internal/auth/kiro/token.go
Normal file
72
internal/auth/kiro/token.go
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
package kiro
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KiroTokenStorage holds the persistent token data for Kiro authentication.
|
||||||
|
type KiroTokenStorage struct {
|
||||||
|
// AccessToken is the OAuth2 access token for API access
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
// RefreshToken is used to obtain new access tokens
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
// ProfileArn is the AWS CodeWhisperer profile ARN
|
||||||
|
ProfileArn string `json:"profile_arn"`
|
||||||
|
// ExpiresAt is the timestamp when the token expires
|
||||||
|
ExpiresAt string `json:"expires_at"`
|
||||||
|
// AuthMethod indicates the authentication method used
|
||||||
|
AuthMethod string `json:"auth_method"`
|
||||||
|
// Provider indicates the OAuth provider
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
// LastRefresh is the timestamp of the last token refresh
|
||||||
|
LastRefresh string `json:"last_refresh"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveTokenToFile persists the token storage to the specified file path.
|
||||||
|
func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
|
dir := filepath.Dir(authFilePath)
|
||||||
|
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(s, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal token storage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(authFilePath, data, 0600); err != nil {
|
||||||
|
return fmt.Errorf("failed to write token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadFromFile loads token storage from the specified file path.
|
||||||
|
func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) {
|
||||||
|
data, err := os.ReadFile(authFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var storage KiroTokenStorage
|
||||||
|
if err := json.Unmarshal(data, &storage); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &storage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToTokenData converts storage to KiroTokenData for API use.
|
||||||
|
func (s *KiroTokenStorage) ToTokenData() *KiroTokenData {
|
||||||
|
return &KiroTokenData{
|
||||||
|
AccessToken: s.AccessToken,
|
||||||
|
RefreshToken: s.RefreshToken,
|
||||||
|
ProfileArn: s.ProfileArn,
|
||||||
|
ExpiresAt: s.ExpiresAt,
|
||||||
|
AuthMethod: s.AuthMethod,
|
||||||
|
Provider: s.Provider,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,14 +6,49 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
pkgbrowser "github.com/pkg/browser"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// incognitoMode controls whether to open URLs in incognito/private mode.
|
||||||
|
// This is useful for OAuth flows where you want to use a different account.
|
||||||
|
var incognitoMode bool
|
||||||
|
|
||||||
|
// lastBrowserProcess stores the last opened browser process for cleanup
|
||||||
|
var lastBrowserProcess *exec.Cmd
|
||||||
|
var browserMutex sync.Mutex
|
||||||
|
|
||||||
|
// SetIncognitoMode enables or disables incognito/private browsing mode.
|
||||||
|
func SetIncognitoMode(enabled bool) {
|
||||||
|
incognitoMode = enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsIncognitoMode returns whether incognito mode is enabled.
|
||||||
|
func IsIncognitoMode() bool {
|
||||||
|
return incognitoMode
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseBrowser closes the last opened browser process.
|
||||||
|
func CloseBrowser() error {
|
||||||
|
browserMutex.Lock()
|
||||||
|
defer browserMutex.Unlock()
|
||||||
|
|
||||||
|
if lastBrowserProcess == nil || lastBrowserProcess.Process == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := lastBrowserProcess.Process.Kill()
|
||||||
|
lastBrowserProcess = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// OpenURL opens the specified URL in the default web browser.
|
// OpenURL opens the specified URL in the default web browser.
|
||||||
// It first attempts to use a platform-agnostic library and falls back to
|
// It uses the pkg/browser library which provides robust cross-platform support
|
||||||
// platform-specific commands if that fails.
|
// for Windows, macOS, and Linux.
|
||||||
|
// If incognito mode is enabled, it will open in a private/incognito window.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - url: The URL to open.
|
// - url: The URL to open.
|
||||||
@@ -21,16 +56,22 @@ import (
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - An error if the URL cannot be opened, otherwise nil.
|
// - An error if the URL cannot be opened, otherwise nil.
|
||||||
func OpenURL(url string) error {
|
func OpenURL(url string) error {
|
||||||
fmt.Printf("Attempting to open URL in browser: %s\n", url)
|
log.Debugf("Opening URL in browser: %s (incognito=%v)", url, incognitoMode)
|
||||||
|
|
||||||
// Try using the open-golang library first
|
// If incognito mode is enabled, use platform-specific incognito commands
|
||||||
err := open.Run(url)
|
if incognitoMode {
|
||||||
|
log.Debug("Using incognito mode")
|
||||||
|
return openURLIncognito(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use pkg/browser for cross-platform support
|
||||||
|
err := pkgbrowser.OpenURL(url)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
log.Debug("Successfully opened URL using open-golang library")
|
log.Debug("Successfully opened URL using pkg/browser library")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("open-golang failed: %v, trying platform-specific commands", err)
|
log.Debugf("pkg/browser failed: %v, trying platform-specific commands", err)
|
||||||
|
|
||||||
// Fallback to platform-specific commands
|
// Fallback to platform-specific commands
|
||||||
return openURLPlatformSpecific(url)
|
return openURLPlatformSpecific(url)
|
||||||
@@ -78,18 +119,379 @@ func openURLPlatformSpecific(url string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// openURLIncognito opens a URL in incognito/private browsing mode.
|
||||||
|
// It first tries to detect the default browser and use its incognito flag.
|
||||||
|
// Falls back to a chain of known browsers if detection fails.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - url: The URL to open.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - An error if the URL cannot be opened, otherwise nil.
|
||||||
|
func openURLIncognito(url string) error {
|
||||||
|
// First, try to detect and use the default browser
|
||||||
|
if cmd := tryDefaultBrowserIncognito(url); cmd != nil {
|
||||||
|
log.Debugf("Using detected default browser: %s %v", cmd.Path, cmd.Args[1:])
|
||||||
|
if err := cmd.Start(); err == nil {
|
||||||
|
storeBrowserProcess(cmd)
|
||||||
|
log.Debug("Successfully opened URL in default browser's incognito mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Debugf("Failed to start default browser, trying fallback chain")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to known browser chain
|
||||||
|
cmd := tryFallbackBrowsersIncognito(url)
|
||||||
|
if cmd == nil {
|
||||||
|
log.Warn("No browser with incognito support found, falling back to normal mode")
|
||||||
|
return openURLPlatformSpecific(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Running incognito command: %s %v", cmd.Path, cmd.Args[1:])
|
||||||
|
err := cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to open incognito browser: %v, falling back to normal mode", err)
|
||||||
|
return openURLPlatformSpecific(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
storeBrowserProcess(cmd)
|
||||||
|
log.Debug("Successfully opened URL in incognito/private mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeBrowserProcess safely stores the browser process for later cleanup.
|
||||||
|
func storeBrowserProcess(cmd *exec.Cmd) {
|
||||||
|
browserMutex.Lock()
|
||||||
|
lastBrowserProcess = cmd
|
||||||
|
browserMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryDefaultBrowserIncognito attempts to detect the default browser and return
|
||||||
|
// an exec.Cmd configured with the appropriate incognito flag.
|
||||||
|
func tryDefaultBrowserIncognito(url string) *exec.Cmd {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
return tryDefaultBrowserMacOS(url)
|
||||||
|
case "windows":
|
||||||
|
return tryDefaultBrowserWindows(url)
|
||||||
|
case "linux":
|
||||||
|
return tryDefaultBrowserLinux(url)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryDefaultBrowserMacOS detects the default browser on macOS.
|
||||||
|
func tryDefaultBrowserMacOS(url string) *exec.Cmd {
|
||||||
|
// Try to get default browser from Launch Services
|
||||||
|
out, err := exec.Command("defaults", "read", "com.apple.LaunchServices/com.apple.launchservices.secure", "LSHandlers").Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
output := string(out)
|
||||||
|
var browserName string
|
||||||
|
|
||||||
|
// Parse the output to find the http/https handler
|
||||||
|
if containsBrowserID(output, "com.google.chrome") {
|
||||||
|
browserName = "chrome"
|
||||||
|
} else if containsBrowserID(output, "org.mozilla.firefox") {
|
||||||
|
browserName = "firefox"
|
||||||
|
} else if containsBrowserID(output, "com.apple.safari") {
|
||||||
|
browserName = "safari"
|
||||||
|
} else if containsBrowserID(output, "com.brave.browser") {
|
||||||
|
browserName = "brave"
|
||||||
|
} else if containsBrowserID(output, "com.microsoft.edgemac") {
|
||||||
|
browserName = "edge"
|
||||||
|
}
|
||||||
|
|
||||||
|
return createMacOSIncognitoCmd(browserName, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
// containsBrowserID checks if the LaunchServices output contains a browser ID.
|
||||||
|
func containsBrowserID(output, bundleID string) bool {
|
||||||
|
return strings.Contains(output, bundleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createMacOSIncognitoCmd creates the appropriate incognito command for macOS browsers.
|
||||||
|
func createMacOSIncognitoCmd(browserName, url string) *exec.Cmd {
|
||||||
|
switch browserName {
|
||||||
|
case "chrome":
|
||||||
|
// Try direct path first
|
||||||
|
chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
|
||||||
|
if _, err := exec.LookPath(chromePath); err == nil {
|
||||||
|
return exec.Command(chromePath, "--incognito", url)
|
||||||
|
}
|
||||||
|
return exec.Command("open", "-na", "Google Chrome", "--args", "--incognito", url)
|
||||||
|
case "firefox":
|
||||||
|
return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url)
|
||||||
|
case "safari":
|
||||||
|
// Safari doesn't have CLI incognito, try AppleScript
|
||||||
|
return tryAppleScriptSafariPrivate(url)
|
||||||
|
case "brave":
|
||||||
|
return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url)
|
||||||
|
case "edge":
|
||||||
|
return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryAppleScriptSafariPrivate attempts to open Safari in private browsing mode using AppleScript.
|
||||||
|
func tryAppleScriptSafariPrivate(url string) *exec.Cmd {
|
||||||
|
// AppleScript to open a new private window in Safari
|
||||||
|
script := fmt.Sprintf(`
|
||||||
|
tell application "Safari"
|
||||||
|
activate
|
||||||
|
tell application "System Events"
|
||||||
|
keystroke "n" using {command down, shift down}
|
||||||
|
delay 0.5
|
||||||
|
end tell
|
||||||
|
set URL of document 1 to "%s"
|
||||||
|
end tell
|
||||||
|
`, url)
|
||||||
|
|
||||||
|
cmd := exec.Command("osascript", "-e", script)
|
||||||
|
// Test if this approach works by checking if Safari is available
|
||||||
|
if _, err := exec.LookPath("/Applications/Safari.app/Contents/MacOS/Safari"); err != nil {
|
||||||
|
log.Debug("Safari not found, AppleScript private window not available")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Debug("Attempting Safari private window via AppleScript")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryDefaultBrowserWindows detects the default browser on Windows via registry.
|
||||||
|
func tryDefaultBrowserWindows(url string) *exec.Cmd {
|
||||||
|
// Query registry for default browser
|
||||||
|
out, err := exec.Command("reg", "query",
|
||||||
|
`HKEY_CURRENT_USER\Software\Microsoft\Windows\Shell\Associations\UrlAssociations\http\UserChoice`,
|
||||||
|
"/v", "ProgId").Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
output := string(out)
|
||||||
|
var browserName string
|
||||||
|
|
||||||
|
// Map ProgId to browser name
|
||||||
|
if strings.Contains(output, "ChromeHTML") {
|
||||||
|
browserName = "chrome"
|
||||||
|
} else if strings.Contains(output, "FirefoxURL") {
|
||||||
|
browserName = "firefox"
|
||||||
|
} else if strings.Contains(output, "MSEdgeHTM") {
|
||||||
|
browserName = "edge"
|
||||||
|
} else if strings.Contains(output, "BraveHTML") {
|
||||||
|
browserName = "brave"
|
||||||
|
}
|
||||||
|
|
||||||
|
return createWindowsIncognitoCmd(browserName, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createWindowsIncognitoCmd creates the appropriate incognito command for Windows browsers.
|
||||||
|
func createWindowsIncognitoCmd(browserName, url string) *exec.Cmd {
|
||||||
|
switch browserName {
|
||||||
|
case "chrome":
|
||||||
|
paths := []string{
|
||||||
|
"chrome",
|
||||||
|
`C:\Program Files\Google\Chrome\Application\chrome.exe`,
|
||||||
|
`C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`,
|
||||||
|
}
|
||||||
|
for _, p := range paths {
|
||||||
|
if _, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(p, "--incognito", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "firefox":
|
||||||
|
if path, err := exec.LookPath("firefox"); err == nil {
|
||||||
|
return exec.Command(path, "--private-window", url)
|
||||||
|
}
|
||||||
|
case "edge":
|
||||||
|
paths := []string{
|
||||||
|
"msedge",
|
||||||
|
`C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`,
|
||||||
|
`C:\Program Files\Microsoft\Edge\Application\msedge.exe`,
|
||||||
|
}
|
||||||
|
for _, p := range paths {
|
||||||
|
if _, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(p, "--inprivate", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "brave":
|
||||||
|
paths := []string{
|
||||||
|
`C:\Program Files\BraveSoftware\Brave-Browser\Application\brave.exe`,
|
||||||
|
`C:\Program Files (x86)\BraveSoftware\Brave-Browser\Application\brave.exe`,
|
||||||
|
}
|
||||||
|
for _, p := range paths {
|
||||||
|
if _, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(p, "--incognito", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryDefaultBrowserLinux detects the default browser on Linux using xdg-settings.
|
||||||
|
func tryDefaultBrowserLinux(url string) *exec.Cmd {
|
||||||
|
out, err := exec.Command("xdg-settings", "get", "default-web-browser").Output()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
desktop := string(out)
|
||||||
|
var browserName string
|
||||||
|
|
||||||
|
// Map .desktop file to browser name
|
||||||
|
if strings.Contains(desktop, "google-chrome") || strings.Contains(desktop, "chrome") {
|
||||||
|
browserName = "chrome"
|
||||||
|
} else if strings.Contains(desktop, "firefox") {
|
||||||
|
browserName = "firefox"
|
||||||
|
} else if strings.Contains(desktop, "chromium") {
|
||||||
|
browserName = "chromium"
|
||||||
|
} else if strings.Contains(desktop, "brave") {
|
||||||
|
browserName = "brave"
|
||||||
|
} else if strings.Contains(desktop, "microsoft-edge") || strings.Contains(desktop, "msedge") {
|
||||||
|
browserName = "edge"
|
||||||
|
}
|
||||||
|
|
||||||
|
return createLinuxIncognitoCmd(browserName, url)
|
||||||
|
}
|
||||||
|
|
||||||
|
// createLinuxIncognitoCmd creates the appropriate incognito command for Linux browsers.
|
||||||
|
func createLinuxIncognitoCmd(browserName, url string) *exec.Cmd {
|
||||||
|
switch browserName {
|
||||||
|
case "chrome":
|
||||||
|
paths := []string{"google-chrome", "google-chrome-stable"}
|
||||||
|
for _, p := range paths {
|
||||||
|
if path, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(path, "--incognito", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "firefox":
|
||||||
|
paths := []string{"firefox", "firefox-esr"}
|
||||||
|
for _, p := range paths {
|
||||||
|
if path, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(path, "--private-window", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "chromium":
|
||||||
|
paths := []string{"chromium", "chromium-browser"}
|
||||||
|
for _, p := range paths {
|
||||||
|
if path, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(path, "--incognito", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "brave":
|
||||||
|
if path, err := exec.LookPath("brave-browser"); err == nil {
|
||||||
|
return exec.Command(path, "--incognito", url)
|
||||||
|
}
|
||||||
|
case "edge":
|
||||||
|
if path, err := exec.LookPath("microsoft-edge"); err == nil {
|
||||||
|
return exec.Command(path, "--inprivate", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryFallbackBrowsersIncognito tries a chain of known browsers as fallback.
|
||||||
|
func tryFallbackBrowsersIncognito(url string) *exec.Cmd {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
return tryFallbackBrowsersMacOS(url)
|
||||||
|
case "windows":
|
||||||
|
return tryFallbackBrowsersWindows(url)
|
||||||
|
case "linux":
|
||||||
|
return tryFallbackBrowsersLinuxChain(url)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryFallbackBrowsersMacOS tries known browsers on macOS.
|
||||||
|
func tryFallbackBrowsersMacOS(url string) *exec.Cmd {
|
||||||
|
// Try Chrome
|
||||||
|
chromePath := "/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"
|
||||||
|
if _, err := exec.LookPath(chromePath); err == nil {
|
||||||
|
return exec.Command(chromePath, "--incognito", url)
|
||||||
|
}
|
||||||
|
// Try Firefox
|
||||||
|
if _, err := exec.LookPath("/Applications/Firefox.app/Contents/MacOS/firefox"); err == nil {
|
||||||
|
return exec.Command("open", "-na", "Firefox", "--args", "--private-window", url)
|
||||||
|
}
|
||||||
|
// Try Brave
|
||||||
|
if _, err := exec.LookPath("/Applications/Brave Browser.app/Contents/MacOS/Brave Browser"); err == nil {
|
||||||
|
return exec.Command("open", "-na", "Brave Browser", "--args", "--incognito", url)
|
||||||
|
}
|
||||||
|
// Try Edge
|
||||||
|
if _, err := exec.LookPath("/Applications/Microsoft Edge.app/Contents/MacOS/Microsoft Edge"); err == nil {
|
||||||
|
return exec.Command("open", "-na", "Microsoft Edge", "--args", "--inprivate", url)
|
||||||
|
}
|
||||||
|
// Last resort: try Safari with AppleScript
|
||||||
|
if cmd := tryAppleScriptSafariPrivate(url); cmd != nil {
|
||||||
|
log.Info("Using Safari with AppleScript for private browsing (may require accessibility permissions)")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryFallbackBrowsersWindows tries known browsers on Windows.
|
||||||
|
func tryFallbackBrowsersWindows(url string) *exec.Cmd {
|
||||||
|
// Chrome
|
||||||
|
chromePaths := []string{
|
||||||
|
"chrome",
|
||||||
|
`C:\Program Files\Google\Chrome\Application\chrome.exe`,
|
||||||
|
`C:\Program Files (x86)\Google\Chrome\Application\chrome.exe`,
|
||||||
|
}
|
||||||
|
for _, p := range chromePaths {
|
||||||
|
if _, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(p, "--incognito", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Firefox
|
||||||
|
if path, err := exec.LookPath("firefox"); err == nil {
|
||||||
|
return exec.Command(path, "--private-window", url)
|
||||||
|
}
|
||||||
|
// Edge (usually available on Windows 10+)
|
||||||
|
edgePaths := []string{
|
||||||
|
"msedge",
|
||||||
|
`C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe`,
|
||||||
|
`C:\Program Files\Microsoft\Edge\Application\msedge.exe`,
|
||||||
|
}
|
||||||
|
for _, p := range edgePaths {
|
||||||
|
if _, err := exec.LookPath(p); err == nil {
|
||||||
|
return exec.Command(p, "--inprivate", url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryFallbackBrowsersLinuxChain tries known browsers on Linux.
|
||||||
|
func tryFallbackBrowsersLinuxChain(url string) *exec.Cmd {
|
||||||
|
type browserConfig struct {
|
||||||
|
name string
|
||||||
|
flag string
|
||||||
|
}
|
||||||
|
browsers := []browserConfig{
|
||||||
|
{"google-chrome", "--incognito"},
|
||||||
|
{"google-chrome-stable", "--incognito"},
|
||||||
|
{"chromium", "--incognito"},
|
||||||
|
{"chromium-browser", "--incognito"},
|
||||||
|
{"firefox", "--private-window"},
|
||||||
|
{"firefox-esr", "--private-window"},
|
||||||
|
{"brave-browser", "--incognito"},
|
||||||
|
{"microsoft-edge", "--inprivate"},
|
||||||
|
}
|
||||||
|
for _, b := range browsers {
|
||||||
|
if path, err := exec.LookPath(b.name); err == nil {
|
||||||
|
return exec.Command(path, b.flag, url)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// IsAvailable checks if the system has a command available to open a web browser.
|
// IsAvailable checks if the system has a command available to open a web browser.
|
||||||
// It verifies the presence of necessary commands for the current operating system.
|
// It verifies the presence of necessary commands for the current operating system.
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - true if a browser can be opened, false otherwise.
|
// - true if a browser can be opened, false otherwise.
|
||||||
func IsAvailable() bool {
|
func IsAvailable() bool {
|
||||||
// First check if open-golang can work
|
|
||||||
testErr := open.Run("about:blank")
|
|
||||||
if testErr == nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check platform-specific commands
|
// Check platform-specific commands
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "darwin":
|
case "darwin":
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
// newAuthManager creates a new authentication manager instance with all supported
|
// newAuthManager creates a new authentication manager instance with all supported
|
||||||
// authenticators and a file-based token store. It initializes authenticators for
|
// authenticators and a file-based token store. It initializes authenticators for
|
||||||
// Gemini, Codex, Claude, and Qwen providers.
|
// Gemini, Codex, Claude, Qwen, IFlow, Antigravity, and GitHub Copilot providers.
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *sdkAuth.Manager: A configured authentication manager instance
|
// - *sdkAuth.Manager: A configured authentication manager instance
|
||||||
@@ -19,6 +19,8 @@ func newAuthManager() *sdkAuth.Manager {
|
|||||||
sdkAuth.NewQwenAuthenticator(),
|
sdkAuth.NewQwenAuthenticator(),
|
||||||
sdkAuth.NewIFlowAuthenticator(),
|
sdkAuth.NewIFlowAuthenticator(),
|
||||||
sdkAuth.NewAntigravityAuthenticator(),
|
sdkAuth.NewAntigravityAuthenticator(),
|
||||||
|
sdkAuth.NewKiroAuthenticator(),
|
||||||
|
sdkAuth.NewGitHubCopilotAuthenticator(),
|
||||||
)
|
)
|
||||||
return manager
|
return manager
|
||||||
}
|
}
|
||||||
|
|||||||
44
internal/cmd/github_copilot_login.go
Normal file
44
internal/cmd/github_copilot_login.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoGitHubCopilotLogin triggers the OAuth device flow for GitHub Copilot and saves tokens.
|
||||||
|
// It initiates the device flow authentication, displays the user code for the user to enter
|
||||||
|
// at GitHub's verification URL, and waits for authorization before saving the tokens.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration containing proxy and auth directory settings
|
||||||
|
// - options: Login options including browser behavior settings
|
||||||
|
func DoGitHubCopilotLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
Prompt: options.Prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
record, savedPath, err := manager.Login(context.Background(), "github-copilot", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("GitHub Copilot authentication failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("GitHub Copilot authentication successful!")
|
||||||
|
}
|
||||||
160
internal/cmd/kiro_login.go
Normal file
160
internal/cmd/kiro_login.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoKiroLogin triggers the Kiro authentication flow with Google OAuth.
|
||||||
|
// This is the default login method (same as --kiro-google-login).
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
// - options: Login options including Prompt field
|
||||||
|
func DoKiroLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
// Use Google login as default
|
||||||
|
DoKiroGoogleLogin(cfg, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoKiroGoogleLogin triggers Kiro authentication with Google OAuth.
|
||||||
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
// - options: Login options including prompts
|
||||||
|
func DoKiroGoogleLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Kiro defaults to incognito mode for multi-account support.
|
||||||
|
// Users can override with --no-incognito if they want to use existing browser sessions.
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
|
||||||
|
// Use KiroAuthenticator with Google login
|
||||||
|
authenticator := sdkAuth.NewKiroAuthenticator()
|
||||||
|
record, err := authenticator.LoginWithGoogle(context.Background(), cfg, &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
Prompt: options.Prompt,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Kiro Google authentication failed: %v", err)
|
||||||
|
fmt.Println("\nTroubleshooting:")
|
||||||
|
fmt.Println("1. Make sure the protocol handler is installed")
|
||||||
|
fmt.Println("2. Complete the Google login in the browser")
|
||||||
|
fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the auth record
|
||||||
|
savedPath, err := manager.SaveAuth(record, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to save auth: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("Kiro Google authentication successful!")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoKiroAWSLogin triggers Kiro authentication with AWS Builder ID.
|
||||||
|
// This uses the device code flow for AWS SSO OIDC authentication.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
// - options: Login options including prompts
|
||||||
|
func DoKiroAWSLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Kiro defaults to incognito mode for multi-account support.
|
||||||
|
// Users can override with --no-incognito if they want to use existing browser sessions.
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
|
||||||
|
// Use KiroAuthenticator with AWS Builder ID login (device code flow)
|
||||||
|
authenticator := sdkAuth.NewKiroAuthenticator()
|
||||||
|
record, err := authenticator.Login(context.Background(), cfg, &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
Prompt: options.Prompt,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Kiro AWS authentication failed: %v", err)
|
||||||
|
fmt.Println("\nTroubleshooting:")
|
||||||
|
fmt.Println("1. Make sure you have an AWS Builder ID")
|
||||||
|
fmt.Println("2. Complete the authorization in the browser")
|
||||||
|
fmt.Println("3. If callback fails, try: --kiro-import (after logging in via Kiro IDE)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the auth record
|
||||||
|
savedPath, err := manager.SaveAuth(record, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to save auth: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("Kiro AWS authentication successful!")
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoKiroImport imports Kiro token from Kiro IDE's token file.
|
||||||
|
// This is useful for users who have already logged in via Kiro IDE
|
||||||
|
// and want to use the same credentials in CLI Proxy API.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
// - options: Login options (currently unused for import)
|
||||||
|
func DoKiroImport(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
|
||||||
|
// Use ImportFromKiroIDE instead of Login
|
||||||
|
authenticator := sdkAuth.NewKiroAuthenticator()
|
||||||
|
record, err := authenticator.ImportFromKiroIDE(context.Background(), cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Kiro token import failed: %v", err)
|
||||||
|
fmt.Println("\nMake sure you have logged in to Kiro IDE first:")
|
||||||
|
fmt.Println("1. Open Kiro IDE")
|
||||||
|
fmt.Println("2. Click 'Sign in with Google' (or GitHub)")
|
||||||
|
fmt.Println("3. Complete the login process")
|
||||||
|
fmt.Println("4. Run this command again")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the imported auth record
|
||||||
|
savedPath, err := manager.SaveAuth(record, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to save auth: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Imported as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("Kiro token import successful!")
|
||||||
|
}
|
||||||
@@ -63,6 +63,13 @@ type Config struct {
|
|||||||
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
||||||
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
||||||
|
|
||||||
|
// KiroKey defines a list of Kiro (AWS CodeWhisperer) configurations.
|
||||||
|
KiroKey []KiroKey `yaml:"kiro" json:"kiro"`
|
||||||
|
|
||||||
|
// KiroPreferredEndpoint sets the global default preferred endpoint for all Kiro providers.
|
||||||
|
// Values: "ide" (default, CodeWhisperer) or "cli" (Amazon Q).
|
||||||
|
KiroPreferredEndpoint string `yaml:"kiro-preferred-endpoint" json:"kiro-preferred-endpoint"`
|
||||||
|
|
||||||
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
|
||||||
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
|
||||||
|
|
||||||
@@ -85,6 +92,11 @@ type Config struct {
|
|||||||
// Payload defines default and override rules for provider payload parameters.
|
// Payload defines default and override rules for provider payload parameters.
|
||||||
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
||||||
|
|
||||||
|
// IncognitoBrowser enables opening OAuth URLs in incognito/private browsing mode.
|
||||||
|
// This is useful when you want to login with a different account without logging out
|
||||||
|
// from your current session. Default: false.
|
||||||
|
IncognitoBrowser bool `yaml:"incognito-browser" json:"incognito-browser"`
|
||||||
|
|
||||||
legacyMigrationPending bool `yaml:"-" json:"-"`
|
legacyMigrationPending bool `yaml:"-" json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -252,6 +264,35 @@ type GeminiKey struct {
|
|||||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// KiroKey represents the configuration for Kiro (AWS CodeWhisperer) authentication.
|
||||||
|
type KiroKey struct {
|
||||||
|
// TokenFile is the path to the Kiro token file (default: ~/.aws/sso/cache/kiro-auth-token.json)
|
||||||
|
TokenFile string `yaml:"token-file,omitempty" json:"token-file,omitempty"`
|
||||||
|
|
||||||
|
// AccessToken is the OAuth access token for direct configuration.
|
||||||
|
AccessToken string `yaml:"access-token,omitempty" json:"access-token,omitempty"`
|
||||||
|
|
||||||
|
// RefreshToken is the OAuth refresh token for token renewal.
|
||||||
|
RefreshToken string `yaml:"refresh-token,omitempty" json:"refresh-token,omitempty"`
|
||||||
|
|
||||||
|
// ProfileArn is the AWS CodeWhisperer profile ARN.
|
||||||
|
ProfileArn string `yaml:"profile-arn,omitempty" json:"profile-arn,omitempty"`
|
||||||
|
|
||||||
|
// Region is the AWS region (default: us-east-1).
|
||||||
|
Region string `yaml:"region,omitempty" json:"region,omitempty"`
|
||||||
|
|
||||||
|
// ProxyURL optionally overrides the global proxy for this configuration.
|
||||||
|
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
|
||||||
|
|
||||||
|
// AgentTaskType sets the Kiro API task type. Known values: "vibe", "dev", "chat".
|
||||||
|
// Leave empty to let API use defaults. Different values may inject different system prompts.
|
||||||
|
AgentTaskType string `yaml:"agent-task-type,omitempty" json:"agent-task-type,omitempty"`
|
||||||
|
|
||||||
|
// PreferredEndpoint sets the preferred Kiro API endpoint/quota.
|
||||||
|
// Values: "codewhisperer" (default, IDE quota) or "amazonq" (CLI quota).
|
||||||
|
PreferredEndpoint string `yaml:"preferred-endpoint,omitempty" json:"preferred-endpoint,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
||||||
// with external providers, allowing model aliases to be routed through OpenAI API format.
|
// with external providers, allowing model aliases to be routed through OpenAI API format.
|
||||||
type OpenAICompatibility struct {
|
type OpenAICompatibility struct {
|
||||||
@@ -334,6 +375,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
cfg.DisableCooling = false
|
cfg.DisableCooling = false
|
||||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||||
|
cfg.IncognitoBrowser = false // Default to normal browser (AWS uses incognito by force)
|
||||||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||||||
if optional {
|
if optional {
|
||||||
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
|
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
|
||||||
@@ -389,6 +431,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
// Sanitize Claude key headers
|
// Sanitize Claude key headers
|
||||||
cfg.SanitizeClaudeKeys()
|
cfg.SanitizeClaudeKeys()
|
||||||
|
|
||||||
|
// Sanitize Kiro keys: trim whitespace from credential fields
|
||||||
|
cfg.SanitizeKiroKeys()
|
||||||
|
|
||||||
// Sanitize OpenAI compatibility providers: drop entries without base-url
|
// Sanitize OpenAI compatibility providers: drop entries without base-url
|
||||||
cfg.SanitizeOpenAICompatibility()
|
cfg.SanitizeOpenAICompatibility()
|
||||||
|
|
||||||
@@ -465,6 +510,23 @@ func (cfg *Config) SanitizeClaudeKeys() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SanitizeKiroKeys trims whitespace from Kiro credential fields.
|
||||||
|
func (cfg *Config) SanitizeKiroKeys() {
|
||||||
|
if cfg == nil || len(cfg.KiroKey) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := range cfg.KiroKey {
|
||||||
|
entry := &cfg.KiroKey[i]
|
||||||
|
entry.TokenFile = strings.TrimSpace(entry.TokenFile)
|
||||||
|
entry.AccessToken = strings.TrimSpace(entry.AccessToken)
|
||||||
|
entry.RefreshToken = strings.TrimSpace(entry.RefreshToken)
|
||||||
|
entry.ProfileArn = strings.TrimSpace(entry.ProfileArn)
|
||||||
|
entry.Region = strings.TrimSpace(entry.Region)
|
||||||
|
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||||
|
entry.PreferredEndpoint = strings.TrimSpace(entry.PreferredEndpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
|
// SanitizeGeminiKeys deduplicates and normalizes Gemini credentials.
|
||||||
func (cfg *Config) SanitizeGeminiKeys() {
|
func (cfg *Config) SanitizeGeminiKeys() {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
|
|||||||
@@ -24,4 +24,7 @@ const (
|
|||||||
|
|
||||||
// Antigravity represents the Antigravity response format identifier.
|
// Antigravity represents the Antigravity response format identifier.
|
||||||
Antigravity = "antigravity"
|
Antigravity = "antigravity"
|
||||||
|
|
||||||
|
// Kiro represents the AWS CodeWhisperer (Kiro) provider identifier.
|
||||||
|
Kiro = "kiro"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -38,13 +38,16 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
|||||||
|
|
||||||
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
||||||
message := strings.TrimRight(entry.Message, "\r\n")
|
message := strings.TrimRight(entry.Message, "\r\n")
|
||||||
|
|
||||||
var formatted string
|
// Handle nil Caller (can happen with some log entries)
|
||||||
|
callerFile := "unknown"
|
||||||
|
callerLine := 0
|
||||||
if entry.Caller != nil {
|
if entry.Caller != nil {
|
||||||
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
|
callerFile = filepath.Base(entry.Caller.File)
|
||||||
} else {
|
callerLine = entry.Caller.Line
|
||||||
formatted = fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, message)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
formatted := fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, callerFile, callerLine, message)
|
||||||
buffer.WriteString(formatted)
|
buffer.WriteString(formatted)
|
||||||
|
|
||||||
return buffer.Bytes(), nil
|
return buffer.Bytes(), nil
|
||||||
@@ -55,6 +58,7 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
|||||||
func SetupBaseLogger() {
|
func SetupBaseLogger() {
|
||||||
setupOnce.Do(func() {
|
setupOnce.Do(func() {
|
||||||
log.SetOutput(os.Stdout)
|
log.SetOutput(os.Stdout)
|
||||||
|
log.SetLevel(log.InfoLevel)
|
||||||
log.SetReportCaller(true)
|
log.SetReportCaller(true)
|
||||||
log.SetFormatter(&LogFormatter{})
|
log.SetFormatter(&LogFormatter{})
|
||||||
|
|
||||||
|
|||||||
@@ -697,3 +697,353 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
|||||||
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetGitHubCopilotModels returns the available models for GitHub Copilot.
|
||||||
|
// These models are available through the GitHub Copilot API at api.githubcopilot.com.
|
||||||
|
func GetGitHubCopilotModels() []*ModelInfo {
|
||||||
|
now := int64(1732752000) // 2024-11-27
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "gpt-4.1",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-4.1",
|
||||||
|
Description: "OpenAI GPT-4.1 via GitHub Copilot",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 16384,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5",
|
||||||
|
Description: "OpenAI GPT-5 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5-mini",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5 Mini",
|
||||||
|
Description: "OpenAI GPT-5 Mini via GitHub Copilot",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 16384,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5-codex",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5 Codex",
|
||||||
|
Description: "OpenAI GPT-5 Codex via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.1",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5.1",
|
||||||
|
Description: "OpenAI GPT-5.1 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.1-codex",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5.1 Codex",
|
||||||
|
Description: "OpenAI GPT-5.1 Codex via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32768,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gpt-5.1-codex-mini",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "GPT-5.1 Codex Mini",
|
||||||
|
Description: "OpenAI GPT-5.1 Codex Mini via GitHub Copilot",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 16384,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-haiku-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Claude Haiku 4.5",
|
||||||
|
Description: "Anthropic Claude Haiku 4.5 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-opus-4.1",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Claude Opus 4.1",
|
||||||
|
Description: "Anthropic Claude Opus 4.1 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 32000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-opus-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Claude Opus 4.5",
|
||||||
|
Description: "Anthropic Claude Opus 4.5 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-sonnet-4",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Claude Sonnet 4",
|
||||||
|
Description: "Anthropic Claude Sonnet 4 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-sonnet-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Claude Sonnet 4.5",
|
||||||
|
Description: "Anthropic Claude Sonnet 4.5 via GitHub Copilot",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-2.5-pro",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Gemini 2.5 Pro",
|
||||||
|
Description: "Google Gemini 2.5 Pro via GitHub Copilot",
|
||||||
|
ContextLength: 1048576,
|
||||||
|
MaxCompletionTokens: 65536,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "gemini-3-pro",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Gemini 3 Pro",
|
||||||
|
Description: "Google Gemini 3 Pro via GitHub Copilot",
|
||||||
|
ContextLength: 1048576,
|
||||||
|
MaxCompletionTokens: 65536,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "grok-code-fast-1",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Grok Code Fast 1",
|
||||||
|
Description: "xAI Grok Code Fast 1 via GitHub Copilot",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 16384,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "raptor-mini",
|
||||||
|
Object: "model",
|
||||||
|
Created: now,
|
||||||
|
OwnedBy: "github-copilot",
|
||||||
|
Type: "github-copilot",
|
||||||
|
DisplayName: "Raptor Mini",
|
||||||
|
Description: "Raptor Mini via GitHub Copilot",
|
||||||
|
ContextLength: 128000,
|
||||||
|
MaxCompletionTokens: 16384,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKiroModels returns the Kiro (AWS CodeWhisperer) model definitions
|
||||||
|
func GetKiroModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
// --- Base Models ---
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-opus-4-5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Opus 4.5",
|
||||||
|
Description: "Claude Opus 4.5 via Kiro (2.2x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-sonnet-4-5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Sonnet 4.5",
|
||||||
|
Description: "Claude Sonnet 4.5 via Kiro (1.3x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-sonnet-4",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Sonnet 4",
|
||||||
|
Description: "Claude Sonnet 4 via Kiro (1.3x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-haiku-4-5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Haiku 4.5",
|
||||||
|
Description: "Claude Haiku 4.5 via Kiro (0.4x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
// --- Agentic Variants (Optimized for coding agents with chunked writes) ---
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-opus-4-5-agentic",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Opus 4.5 (Agentic)",
|
||||||
|
Description: "Claude Opus 4.5 optimized for coding agents (chunked writes)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-sonnet-4-5-agentic",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Sonnet 4.5 (Agentic)",
|
||||||
|
Description: "Claude Sonnet 4.5 optimized for coding agents (chunked writes)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-sonnet-4-agentic",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Sonnet 4 (Agentic)",
|
||||||
|
Description: "Claude Sonnet 4 optimized for coding agents (chunked writes)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "kiro-claude-haiku-4-5-agentic",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Kiro Claude Haiku 4.5 (Agentic)",
|
||||||
|
Description: "Claude Haiku 4.5 optimized for coding agents (chunked writes)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAmazonQModels returns the Amazon Q (AWS CodeWhisperer) model definitions.
|
||||||
|
// These models use the same API as Kiro and share the same executor.
|
||||||
|
func GetAmazonQModels() []*ModelInfo {
|
||||||
|
return []*ModelInfo{
|
||||||
|
{
|
||||||
|
ID: "amazonq-auto",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro", // Uses Kiro executor - same API
|
||||||
|
DisplayName: "Amazon Q Auto",
|
||||||
|
Description: "Automatic model selection by Amazon Q",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "amazonq-claude-opus-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Amazon Q Claude Opus 4.5",
|
||||||
|
Description: "Claude Opus 4.5 via Amazon Q (2.2x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "amazonq-claude-sonnet-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Amazon Q Claude Sonnet 4.5",
|
||||||
|
Description: "Claude Sonnet 4.5 via Amazon Q (1.3x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "amazonq-claude-sonnet-4",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Amazon Q Claude Sonnet 4",
|
||||||
|
Description: "Claude Sonnet 4 via Amazon Q (1.3x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "amazonq-claude-haiku-4.5",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1732752000,
|
||||||
|
OwnedBy: "aws",
|
||||||
|
Type: "kiro",
|
||||||
|
DisplayName: "Amazon Q Claude Haiku 4.5",
|
||||||
|
Description: "Claude Haiku 4.5 via Amazon Q (0.4x credit)",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -766,7 +766,8 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
case "claude":
|
case "claude", "kiro", "antigravity":
|
||||||
|
// Claude, Kiro, and Antigravity all use Claude-compatible format for Claude Code client
|
||||||
result := map[string]any{
|
result := map[string]any{
|
||||||
"id": model.ID,
|
"id": model.ID,
|
||||||
"object": "model",
|
"object": "model",
|
||||||
@@ -781,6 +782,19 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
if model.DisplayName != "" {
|
if model.DisplayName != "" {
|
||||||
result["display_name"] = model.DisplayName
|
result["display_name"] = model.DisplayName
|
||||||
}
|
}
|
||||||
|
// Add thinking support for Claude Code client
|
||||||
|
// Claude Code checks for "thinking" field (simple boolean) to enable tab toggle
|
||||||
|
// Also add "extended_thinking" for detailed budget info
|
||||||
|
if model.Thinking != nil {
|
||||||
|
result["thinking"] = true
|
||||||
|
result["extended_thinking"] = map[string]any{
|
||||||
|
"supported": true,
|
||||||
|
"min": model.Thinking.Min,
|
||||||
|
"max": model.Thinking.Max,
|
||||||
|
"zero_allowed": model.Thinking.ZeroAllowed,
|
||||||
|
"dynamic_allowed": model.Thinking.DynamicAllowed,
|
||||||
|
}
|
||||||
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
case "gemini":
|
case "gemini":
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -43,7 +44,10 @@ const (
|
|||||||
refreshSkew = 3000 * time.Second
|
refreshSkew = 3000 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
var (
|
||||||
|
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
randSourceMutex sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||||
type AntigravityExecutor struct {
|
type AntigravityExecutor struct {
|
||||||
@@ -777,15 +781,19 @@ func generateRequestID() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func generateSessionID() string {
|
func generateSessionID() string {
|
||||||
|
randSourceMutex.Lock()
|
||||||
n := randSource.Int63n(9_000_000_000_000_000_000)
|
n := randSource.Int63n(9_000_000_000_000_000_000)
|
||||||
|
randSourceMutex.Unlock()
|
||||||
return "-" + strconv.FormatInt(n, 10)
|
return "-" + strconv.FormatInt(n, 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateProjectID() string {
|
func generateProjectID() string {
|
||||||
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
||||||
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
||||||
|
randSourceMutex.Lock()
|
||||||
adj := adjectives[randSource.Intn(len(adjectives))]
|
adj := adjectives[randSource.Intn(len(adjectives))]
|
||||||
noun := nouns[randSource.Intn(len(nouns))]
|
noun := nouns[randSource.Intn(len(nouns))]
|
||||||
|
randSourceMutex.Unlock()
|
||||||
randomPart := strings.ToLower(uuid.NewString())[:5]
|
randomPart := strings.ToLower(uuid.NewString())[:5]
|
||||||
return adj + "-" + noun + "-" + randomPart
|
return adj + "-" + noun + "-" + randomPart
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,38 @@
|
|||||||
package executor
|
package executor
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
type codexCache struct {
|
type codexCache struct {
|
||||||
ID string
|
ID string
|
||||||
Expire time.Time
|
Expire time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
var codexCacheMap = map[string]codexCache{}
|
var (
|
||||||
|
codexCacheMap = map[string]codexCache{}
|
||||||
|
codexCacheMutex sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// getCodexCache safely retrieves a cache entry
|
||||||
|
func getCodexCache(key string) (codexCache, bool) {
|
||||||
|
codexCacheMutex.RLock()
|
||||||
|
defer codexCacheMutex.RUnlock()
|
||||||
|
cache, ok := codexCacheMap[key]
|
||||||
|
return cache, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// setCodexCache safely sets a cache entry
|
||||||
|
func setCodexCache(key string, cache codexCache) {
|
||||||
|
codexCacheMutex.Lock()
|
||||||
|
defer codexCacheMutex.Unlock()
|
||||||
|
codexCacheMap[key] = cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteCodexCache safely deletes a cache entry
|
||||||
|
func deleteCodexCache(key string) {
|
||||||
|
codexCacheMutex.Lock()
|
||||||
|
defer codexCacheMutex.Unlock()
|
||||||
|
delete(codexCacheMap, key)
|
||||||
|
}
|
||||||
|
|||||||
@@ -442,12 +442,12 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
|
|||||||
if userIDResult.Exists() {
|
if userIDResult.Exists() {
|
||||||
var hasKey bool
|
var hasKey bool
|
||||||
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String())
|
||||||
if cache, hasKey = codexCacheMap[key]; !hasKey || cache.Expire.Before(time.Now()) {
|
if cache, hasKey = getCodexCache(key); !hasKey || cache.Expire.Before(time.Now()) {
|
||||||
cache = codexCache{
|
cache = codexCache{
|
||||||
ID: uuid.New().String(),
|
ID: uuid.New().String(),
|
||||||
Expire: time.Now().Add(1 * time.Hour),
|
Expire: time.Now().Add(1 * time.Hour),
|
||||||
}
|
}
|
||||||
codexCacheMap[key] = cache
|
setCodexCache(key, cache)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if from == "openai-response" {
|
} else if from == "openai-response" {
|
||||||
|
|||||||
361
internal/runtime/executor/github_copilot_executor.go
Normal file
361
internal/runtime/executor/github_copilot_executor.go
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
copilotauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
githubCopilotBaseURL = "https://api.githubcopilot.com"
|
||||||
|
githubCopilotChatPath = "/chat/completions"
|
||||||
|
githubCopilotAuthType = "github-copilot"
|
||||||
|
githubCopilotTokenCacheTTL = 25 * time.Minute
|
||||||
|
// tokenExpiryBuffer is the time before expiry when we should refresh the token.
|
||||||
|
tokenExpiryBuffer = 5 * time.Minute
|
||||||
|
// maxScannerBufferSize is the maximum buffer size for SSE scanning (20MB).
|
||||||
|
maxScannerBufferSize = 20_971_520
|
||||||
|
|
||||||
|
// Copilot API header values.
|
||||||
|
copilotUserAgent = "GithubCopilot/1.0"
|
||||||
|
copilotEditorVersion = "vscode/1.100.0"
|
||||||
|
copilotPluginVersion = "copilot/1.300.0"
|
||||||
|
copilotIntegrationID = "vscode-chat"
|
||||||
|
copilotOpenAIIntent = "conversation-panel"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GitHubCopilotExecutor handles requests to the GitHub Copilot API.
|
||||||
|
type GitHubCopilotExecutor struct {
|
||||||
|
cfg *config.Config
|
||||||
|
mu sync.RWMutex
|
||||||
|
cache map[string]*cachedAPIToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// cachedAPIToken stores a cached Copilot API token with its expiry.
|
||||||
|
type cachedAPIToken struct {
|
||||||
|
token string
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGitHubCopilotExecutor constructs a new executor instance.
|
||||||
|
func NewGitHubCopilotExecutor(cfg *config.Config) *GitHubCopilotExecutor {
|
||||||
|
return &GitHubCopilotExecutor{
|
||||||
|
cfg: cfg,
|
||||||
|
cache: make(map[string]*cachedAPIToken),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identifier implements ProviderExecutor.
|
||||||
|
func (e *GitHubCopilotExecutor) Identifier() string { return githubCopilotAuthType }
|
||||||
|
|
||||||
|
// PrepareRequest implements ProviderExecutor.
|
||||||
|
func (e *GitHubCopilotExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute handles non-streaming requests to GitHub Copilot.
|
||||||
|
func (e *GitHubCopilotExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
|
apiToken, errToken := e.ensureAPIToken(ctx, auth)
|
||||||
|
if errToken != nil {
|
||||||
|
return resp, errToken
|
||||||
|
}
|
||||||
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("openai")
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
|
body = e.normalizeModel(req.Model, body)
|
||||||
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
|
body, _ = sjson.SetBytes(body, "stream", false)
|
||||||
|
|
||||||
|
url := githubCopilotBaseURL + githubCopilotChatPath
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
e.applyHeaders(httpReq, apiToken)
|
||||||
|
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: url,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: httpReq.Header.Clone(),
|
||||||
|
Body: body,
|
||||||
|
Provider: e.Identifier(),
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("github-copilot executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
|
||||||
|
if !isHTTPSuccess(httpResp.StatusCode) {
|
||||||
|
data, _ := io.ReadAll(httpResp.Body)
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
|
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
|
||||||
|
detail := parseOpenAIUsage(data)
|
||||||
|
if detail.TotalTokens > 0 {
|
||||||
|
reporter.publish(ctx, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
var param any
|
||||||
|
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||||
|
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteStream handles streaming requests to GitHub Copilot.
|
||||||
|
func (e *GitHubCopilotExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
|
apiToken, errToken := e.ensureAPIToken(ctx, auth)
|
||||||
|
if errToken != nil {
|
||||||
|
return nil, errToken
|
||||||
|
}
|
||||||
|
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("openai")
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
|
body = e.normalizeModel(req.Model, body)
|
||||||
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||||
|
body, _ = sjson.SetBytes(body, "stream", true)
|
||||||
|
// Enable stream options for usage stats in stream
|
||||||
|
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||||
|
|
||||||
|
url := githubCopilotBaseURL + githubCopilotChatPath
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
e.applyHeaders(httpReq, apiToken)
|
||||||
|
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: url,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: httpReq.Header.Clone(),
|
||||||
|
Body: body,
|
||||||
|
Provider: e.Identifier(),
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
|
||||||
|
if !isHTTPSuccess(httpResp.StatusCode) {
|
||||||
|
data, readErr := io.ReadAll(httpResp.Body)
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("github-copilot executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
if readErr != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, readErr)
|
||||||
|
return nil, readErr
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
log.Debugf("github-copilot executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
|
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
stream = out
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(out)
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("github-copilot executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(httpResp.Body)
|
||||||
|
scanner.Buffer(nil, maxScannerBufferSize)
|
||||||
|
var param any
|
||||||
|
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
|
||||||
|
// Parse SSE data
|
||||||
|
if bytes.HasPrefix(line, dataTag) {
|
||||||
|
data := bytes.TrimSpace(line[5:])
|
||||||
|
if bytes.Equal(data, []byte("[DONE]")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||||
|
reporter.publish(ctx, detail)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||||
|
for i := range chunks {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
|
reporter.publishFailure(ctx)
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||||
|
} else {
|
||||||
|
reporter.ensurePublished(ctx)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountTokens is not supported for GitHub Copilot.
|
||||||
|
func (e *GitHubCopilotExecutor) CountTokens(_ context.Context, _ *cliproxyauth.Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported for github-copilot"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh validates the GitHub token is still working.
|
||||||
|
// GitHub OAuth tokens don't expire traditionally, so we just validate.
|
||||||
|
func (e *GitHubCopilotExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
|
if auth == nil {
|
||||||
|
return nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the GitHub access token
|
||||||
|
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||||
|
if accessToken == "" {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the token can still get a Copilot API token
|
||||||
|
copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
|
||||||
|
_, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("github-copilot token validation failed: %v", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureAPIToken gets or refreshes the Copilot API token.
|
||||||
|
func (e *GitHubCopilotExecutor) ensureAPIToken(ctx context.Context, auth *cliproxyauth.Auth) (string, error) {
|
||||||
|
if auth == nil {
|
||||||
|
return "", statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the GitHub access token
|
||||||
|
accessToken := metaStringValue(auth.Metadata, "access_token")
|
||||||
|
if accessToken == "" {
|
||||||
|
return "", statusErr{code: http.StatusUnauthorized, msg: "missing github access token"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for cached API token using thread-safe access
|
||||||
|
e.mu.RLock()
|
||||||
|
if cached, ok := e.cache[accessToken]; ok && cached.expiresAt.After(time.Now().Add(tokenExpiryBuffer)) {
|
||||||
|
e.mu.RUnlock()
|
||||||
|
return cached.token, nil
|
||||||
|
}
|
||||||
|
e.mu.RUnlock()
|
||||||
|
|
||||||
|
// Get a new Copilot API token
|
||||||
|
copilotAuth := copilotauth.NewCopilotAuth(e.cfg)
|
||||||
|
apiToken, err := copilotAuth.GetCopilotAPIToken(ctx, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
return "", statusErr{code: http.StatusUnauthorized, msg: fmt.Sprintf("failed to get copilot api token: %v", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache the token with thread-safe access
|
||||||
|
expiresAt := time.Now().Add(githubCopilotTokenCacheTTL)
|
||||||
|
if apiToken.ExpiresAt > 0 {
|
||||||
|
expiresAt = time.Unix(apiToken.ExpiresAt, 0)
|
||||||
|
}
|
||||||
|
e.mu.Lock()
|
||||||
|
e.cache[accessToken] = &cachedAPIToken{
|
||||||
|
token: apiToken.Token,
|
||||||
|
expiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
e.mu.Unlock()
|
||||||
|
|
||||||
|
return apiToken.Token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyHeaders sets the required headers for GitHub Copilot API requests.
|
||||||
|
func (e *GitHubCopilotExecutor) applyHeaders(r *http.Request, apiToken string) {
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
r.Header.Set("Authorization", "Bearer "+apiToken)
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
r.Header.Set("User-Agent", copilotUserAgent)
|
||||||
|
r.Header.Set("Editor-Version", copilotEditorVersion)
|
||||||
|
r.Header.Set("Editor-Plugin-Version", copilotPluginVersion)
|
||||||
|
r.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
||||||
|
r.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
||||||
|
r.Header.Set("X-Request-Id", uuid.NewString())
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeModel is a no-op as GitHub Copilot accepts model names directly.
|
||||||
|
// Model mapping should be done at the registry level if needed.
|
||||||
|
func (e *GitHubCopilotExecutor) normalizeModel(_ string, body []byte) []byte {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// isHTTPSuccess checks if the status code indicates success (2xx).
|
||||||
|
func isHTTPSuccess(statusCode int) bool {
|
||||||
|
return statusCode >= 200 && statusCode < 300
|
||||||
|
}
|
||||||
3195
internal/runtime/executor/kiro_executor.go
Normal file
3195
internal/runtime/executor/kiro_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -14,11 +15,19 @@ import (
|
|||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// httpClientCache caches HTTP clients by proxy URL to enable connection reuse
|
||||||
|
var (
|
||||||
|
httpClientCache = make(map[string]*http.Client)
|
||||||
|
httpClientCacheMutex sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
||||||
// 1. Use auth.ProxyURL if configured (highest priority)
|
// 1. Use auth.ProxyURL if configured (highest priority)
|
||||||
// 2. Use cfg.ProxyURL if auth proxy is not configured
|
// 2. Use cfg.ProxyURL if auth proxy is not configured
|
||||||
// 3. Use RoundTripper from context if neither are configured
|
// 3. Use RoundTripper from context if neither are configured
|
||||||
//
|
//
|
||||||
|
// This function caches HTTP clients by proxy URL to enable TCP/TLS connection reuse.
|
||||||
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - ctx: The context containing optional RoundTripper
|
// - ctx: The context containing optional RoundTripper
|
||||||
// - cfg: The application configuration
|
// - cfg: The application configuration
|
||||||
@@ -28,11 +37,6 @@ import (
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - *http.Client: An HTTP client with configured proxy or transport
|
// - *http.Client: An HTTP client with configured proxy or transport
|
||||||
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||||
httpClient := &http.Client{}
|
|
||||||
if timeout > 0 {
|
|
||||||
httpClient.Timeout = timeout
|
|
||||||
}
|
|
||||||
|
|
||||||
// Priority 1: Use auth.ProxyURL if configured
|
// Priority 1: Use auth.ProxyURL if configured
|
||||||
var proxyURL string
|
var proxyURL string
|
||||||
if auth != nil {
|
if auth != nil {
|
||||||
@@ -44,11 +48,39 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
|||||||
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
proxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build cache key from proxy URL (empty string for no proxy)
|
||||||
|
cacheKey := proxyURL
|
||||||
|
|
||||||
|
// Check cache first
|
||||||
|
httpClientCacheMutex.RLock()
|
||||||
|
if cachedClient, ok := httpClientCache[cacheKey]; ok {
|
||||||
|
httpClientCacheMutex.RUnlock()
|
||||||
|
// Return a wrapper with the requested timeout but shared transport
|
||||||
|
if timeout > 0 {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: cachedClient.Transport,
|
||||||
|
Timeout: timeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cachedClient
|
||||||
|
}
|
||||||
|
httpClientCacheMutex.RUnlock()
|
||||||
|
|
||||||
|
// Create new client
|
||||||
|
httpClient := &http.Client{}
|
||||||
|
if timeout > 0 {
|
||||||
|
httpClient.Timeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
// If we have a proxy URL configured, set up the transport
|
// If we have a proxy URL configured, set up the transport
|
||||||
if proxyURL != "" {
|
if proxyURL != "" {
|
||||||
transport := buildProxyTransport(proxyURL)
|
transport := buildProxyTransport(proxyURL)
|
||||||
if transport != nil {
|
if transport != nil {
|
||||||
httpClient.Transport = transport
|
httpClient.Transport = transport
|
||||||
|
// Cache the client
|
||||||
|
httpClientCacheMutex.Lock()
|
||||||
|
httpClientCache[cacheKey] = httpClient
|
||||||
|
httpClientCacheMutex.Unlock()
|
||||||
return httpClient
|
return httpClient
|
||||||
}
|
}
|
||||||
// If proxy setup failed, log and fall through to context RoundTripper
|
// If proxy setup failed, log and fall through to context RoundTripper
|
||||||
@@ -60,6 +92,13 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
|||||||
httpClient.Transport = rt
|
httpClient.Transport = rt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cache the client for no-proxy case
|
||||||
|
if proxyURL == "" {
|
||||||
|
httpClientCacheMutex.Lock()
|
||||||
|
httpClientCache[cacheKey] = httpClient
|
||||||
|
httpClientCacheMutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
return httpClient
|
return httpClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,43 +2,107 @@ package executor
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tiktoken-go/tokenizer"
|
"github.com/tiktoken-go/tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// tokenizerCache stores tokenizer instances to avoid repeated creation
|
||||||
|
var tokenizerCache sync.Map
|
||||||
|
|
||||||
|
// TokenizerWrapper wraps a tokenizer codec with an adjustment factor for models
|
||||||
|
// where tiktoken may not accurately estimate token counts (e.g., Claude models)
|
||||||
|
type TokenizerWrapper struct {
|
||||||
|
Codec tokenizer.Codec
|
||||||
|
AdjustmentFactor float64 // 1.0 means no adjustment, >1.0 means tiktoken underestimates
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the token count with adjustment factor applied
|
||||||
|
func (tw *TokenizerWrapper) Count(text string) (int, error) {
|
||||||
|
count, err := tw.Codec.Count(text)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if tw.AdjustmentFactor != 1.0 && tw.AdjustmentFactor > 0 {
|
||||||
|
return int(float64(count) * tw.AdjustmentFactor), nil
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTokenizer returns a cached tokenizer for the given model.
|
||||||
|
// This improves performance by avoiding repeated tokenizer creation.
|
||||||
|
func getTokenizer(model string) (*TokenizerWrapper, error) {
|
||||||
|
// Check cache first
|
||||||
|
if cached, ok := tokenizerCache.Load(model); ok {
|
||||||
|
return cached.(*TokenizerWrapper), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache miss, create new tokenizer
|
||||||
|
wrapper, err := tokenizerForModel(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store in cache (use LoadOrStore to handle race conditions)
|
||||||
|
actual, _ := tokenizerCache.LoadOrStore(model, wrapper)
|
||||||
|
return actual.(*TokenizerWrapper), nil
|
||||||
|
}
|
||||||
|
|
||||||
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
|
||||||
func tokenizerForModel(model string) (tokenizer.Codec, error) {
|
// For Claude models, applies a 1.1 adjustment factor since tiktoken may underestimate.
|
||||||
|
func tokenizerForModel(model string) (*TokenizerWrapper, error) {
|
||||||
sanitized := strings.ToLower(strings.TrimSpace(model))
|
sanitized := strings.ToLower(strings.TrimSpace(model))
|
||||||
|
|
||||||
|
// Claude models use cl100k_base with 1.1 adjustment factor
|
||||||
|
// because tiktoken may underestimate Claude's actual token count
|
||||||
|
if strings.Contains(sanitized, "claude") || strings.HasPrefix(sanitized, "kiro-") || strings.HasPrefix(sanitized, "amazonq-") {
|
||||||
|
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.1}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var enc tokenizer.Codec
|
||||||
|
var err error
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case sanitized == "":
|
case sanitized == "":
|
||||||
return tokenizer.Get(tokenizer.Cl100kBase)
|
enc, err = tokenizer.Get(tokenizer.Cl100kBase)
|
||||||
case strings.HasPrefix(sanitized, "gpt-5"):
|
case strings.HasPrefix(sanitized, "gpt-5"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT5)
|
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||||
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
case strings.HasPrefix(sanitized, "gpt-5.1"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT5)
|
enc, err = tokenizer.ForModel(tokenizer.GPT5)
|
||||||
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
case strings.HasPrefix(sanitized, "gpt-4.1"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT41)
|
enc, err = tokenizer.ForModel(tokenizer.GPT41)
|
||||||
case strings.HasPrefix(sanitized, "gpt-4o"):
|
case strings.HasPrefix(sanitized, "gpt-4o"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT4o)
|
enc, err = tokenizer.ForModel(tokenizer.GPT4o)
|
||||||
case strings.HasPrefix(sanitized, "gpt-4"):
|
case strings.HasPrefix(sanitized, "gpt-4"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT4)
|
enc, err = tokenizer.ForModel(tokenizer.GPT4)
|
||||||
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
case strings.HasPrefix(sanitized, "gpt-3.5"), strings.HasPrefix(sanitized, "gpt-3"):
|
||||||
return tokenizer.ForModel(tokenizer.GPT35Turbo)
|
enc, err = tokenizer.ForModel(tokenizer.GPT35Turbo)
|
||||||
case strings.HasPrefix(sanitized, "o1"):
|
case strings.HasPrefix(sanitized, "o1"):
|
||||||
return tokenizer.ForModel(tokenizer.O1)
|
enc, err = tokenizer.ForModel(tokenizer.O1)
|
||||||
case strings.HasPrefix(sanitized, "o3"):
|
case strings.HasPrefix(sanitized, "o3"):
|
||||||
return tokenizer.ForModel(tokenizer.O3)
|
enc, err = tokenizer.ForModel(tokenizer.O3)
|
||||||
case strings.HasPrefix(sanitized, "o4"):
|
case strings.HasPrefix(sanitized, "o4"):
|
||||||
return tokenizer.ForModel(tokenizer.O4Mini)
|
enc, err = tokenizer.ForModel(tokenizer.O4Mini)
|
||||||
default:
|
default:
|
||||||
return tokenizer.Get(tokenizer.O200kBase)
|
enc, err = tokenizer.Get(tokenizer.O200kBase)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &TokenizerWrapper{Codec: enc, AdjustmentFactor: 1.0}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
// countOpenAIChatTokens approximates prompt tokens for OpenAI chat completions payloads.
|
||||||
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
func countOpenAIChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||||
if enc == nil {
|
if enc == nil {
|
||||||
return 0, fmt.Errorf("encoder is nil")
|
return 0, fmt.Errorf("encoder is nil")
|
||||||
}
|
}
|
||||||
@@ -62,11 +126,206 @@ func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Count text tokens
|
||||||
count, err := enc.Count(joined)
|
count, err := enc.Count(joined)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return int64(count), nil
|
|
||||||
|
// Extract and add image tokens from placeholders
|
||||||
|
imageTokens := extractImageTokens(joined)
|
||||||
|
|
||||||
|
return int64(count) + int64(imageTokens), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// countClaudeChatTokens approximates prompt tokens for Claude API chat completions payloads.
|
||||||
|
// This handles Claude's message format with system, messages, and tools.
|
||||||
|
// Image tokens are estimated based on image dimensions when available.
|
||||||
|
func countClaudeChatTokens(enc *TokenizerWrapper, payload []byte) (int64, error) {
|
||||||
|
if enc == nil {
|
||||||
|
return 0, fmt.Errorf("encoder is nil")
|
||||||
|
}
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
root := gjson.ParseBytes(payload)
|
||||||
|
segments := make([]string, 0, 32)
|
||||||
|
|
||||||
|
// Collect system prompt (can be string or array of content blocks)
|
||||||
|
collectClaudeSystem(root.Get("system"), &segments)
|
||||||
|
|
||||||
|
// Collect messages
|
||||||
|
collectClaudeMessages(root.Get("messages"), &segments)
|
||||||
|
|
||||||
|
// Collect tools
|
||||||
|
collectClaudeTools(root.Get("tools"), &segments)
|
||||||
|
|
||||||
|
joined := strings.TrimSpace(strings.Join(segments, "\n"))
|
||||||
|
if joined == "" {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count text tokens
|
||||||
|
count, err := enc.Count(joined)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and add image tokens from placeholders
|
||||||
|
imageTokens := extractImageTokens(joined)
|
||||||
|
|
||||||
|
return int64(count) + int64(imageTokens), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// imageTokenPattern matches [IMAGE:xxx tokens] format for extracting estimated image tokens
|
||||||
|
var imageTokenPattern = regexp.MustCompile(`\[IMAGE:(\d+) tokens\]`)
|
||||||
|
|
||||||
|
// extractImageTokens extracts image token estimates from placeholder text.
|
||||||
|
// Placeholders are in the format [IMAGE:xxx tokens] where xxx is the estimated token count.
|
||||||
|
func extractImageTokens(text string) int {
|
||||||
|
matches := imageTokenPattern.FindAllStringSubmatch(text, -1)
|
||||||
|
total := 0
|
||||||
|
for _, match := range matches {
|
||||||
|
if len(match) > 1 {
|
||||||
|
if tokens, err := strconv.Atoi(match[1]); err == nil {
|
||||||
|
total += tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
// estimateImageTokens calculates estimated tokens for an image based on dimensions.
|
||||||
|
// Based on Claude's image token calculation: tokens ≈ (width * height) / 750
|
||||||
|
// Minimum 85 tokens, maximum 1590 tokens (for 1568x1568 images).
|
||||||
|
func estimateImageTokens(width, height float64) int {
|
||||||
|
if width <= 0 || height <= 0 {
|
||||||
|
// No valid dimensions, use default estimate (medium-sized image)
|
||||||
|
return 1000
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := int(width * height / 750)
|
||||||
|
|
||||||
|
// Apply bounds
|
||||||
|
if tokens < 85 {
|
||||||
|
tokens = 85
|
||||||
|
}
|
||||||
|
if tokens > 1590 {
|
||||||
|
tokens = 1590
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectClaudeSystem extracts text from Claude's system field.
|
||||||
|
// System can be a string or an array of content blocks.
|
||||||
|
func collectClaudeSystem(system gjson.Result, segments *[]string) {
|
||||||
|
if !system.Exists() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if system.Type == gjson.String {
|
||||||
|
addIfNotEmpty(segments, system.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if system.IsArray() {
|
||||||
|
system.ForEach(func(_, block gjson.Result) bool {
|
||||||
|
blockType := block.Get("type").String()
|
||||||
|
if blockType == "text" || blockType == "" {
|
||||||
|
addIfNotEmpty(segments, block.Get("text").String())
|
||||||
|
}
|
||||||
|
// Also handle plain string blocks
|
||||||
|
if block.Type == gjson.String {
|
||||||
|
addIfNotEmpty(segments, block.String())
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectClaudeMessages extracts text from Claude's messages array.
|
||||||
|
func collectClaudeMessages(messages gjson.Result, segments *[]string) {
|
||||||
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
messages.ForEach(func(_, message gjson.Result) bool {
|
||||||
|
addIfNotEmpty(segments, message.Get("role").String())
|
||||||
|
collectClaudeContent(message.Get("content"), segments)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectClaudeContent extracts text from Claude's content field.
|
||||||
|
// Content can be a string or an array of content blocks.
|
||||||
|
// For images, estimates token count based on dimensions when available.
|
||||||
|
func collectClaudeContent(content gjson.Result, segments *[]string) {
|
||||||
|
if !content.Exists() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
addIfNotEmpty(segments, content.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if content.IsArray() {
|
||||||
|
content.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
partType := part.Get("type").String()
|
||||||
|
switch partType {
|
||||||
|
case "text":
|
||||||
|
addIfNotEmpty(segments, part.Get("text").String())
|
||||||
|
case "image":
|
||||||
|
// Estimate image tokens based on dimensions if available
|
||||||
|
source := part.Get("source")
|
||||||
|
if source.Exists() {
|
||||||
|
width := source.Get("width").Float()
|
||||||
|
height := source.Get("height").Float()
|
||||||
|
if width > 0 && height > 0 {
|
||||||
|
tokens := estimateImageTokens(width, height)
|
||||||
|
addIfNotEmpty(segments, fmt.Sprintf("[IMAGE:%d tokens]", tokens))
|
||||||
|
} else {
|
||||||
|
// No dimensions available, use default estimate
|
||||||
|
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No source info, use default estimate
|
||||||
|
addIfNotEmpty(segments, "[IMAGE:1000 tokens]")
|
||||||
|
}
|
||||||
|
case "tool_use":
|
||||||
|
addIfNotEmpty(segments, part.Get("id").String())
|
||||||
|
addIfNotEmpty(segments, part.Get("name").String())
|
||||||
|
if input := part.Get("input"); input.Exists() {
|
||||||
|
addIfNotEmpty(segments, input.Raw)
|
||||||
|
}
|
||||||
|
case "tool_result":
|
||||||
|
addIfNotEmpty(segments, part.Get("tool_use_id").String())
|
||||||
|
collectClaudeContent(part.Get("content"), segments)
|
||||||
|
case "thinking":
|
||||||
|
addIfNotEmpty(segments, part.Get("thinking").String())
|
||||||
|
default:
|
||||||
|
// For unknown types, try to extract any text content
|
||||||
|
if part.Type == gjson.String {
|
||||||
|
addIfNotEmpty(segments, part.String())
|
||||||
|
} else if part.Type == gjson.JSON {
|
||||||
|
addIfNotEmpty(segments, part.Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectClaudeTools extracts text from Claude's tools array.
|
||||||
|
func collectClaudeTools(tools gjson.Result, segments *[]string) {
|
||||||
|
if !tools.Exists() || !tools.IsArray() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||||
|
addIfNotEmpty(segments, tool.Get("name").String())
|
||||||
|
addIfNotEmpty(segments, tool.Get("description").String())
|
||||||
|
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
||||||
|
addIfNotEmpty(segments, inputSchema.Raw)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
// buildOpenAIUsageJSON returns a minimal usage structure understood by downstream translators.
|
||||||
|
|||||||
@@ -50,6 +50,10 @@ type ToolCallAccumulator struct {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response
|
||||||
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
|
||||||
|
var localParam any
|
||||||
|
if param == nil {
|
||||||
|
param = &localParam
|
||||||
|
}
|
||||||
if *param == nil {
|
if *param == nil {
|
||||||
*param = &ConvertAnthropicResponseToOpenAIParams{
|
*param = &ConvertAnthropicResponseToOpenAIParams{
|
||||||
CreatedAt: 0,
|
CreatedAt: 0,
|
||||||
|
|||||||
@@ -33,4 +33,7 @@ import (
|
|||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses"
|
||||||
|
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/claude"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/openai"
|
||||||
)
|
)
|
||||||
|
|||||||
20
internal/translator/kiro/claude/init.go
Normal file
20
internal/translator/kiro/claude/init.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
// Package claude provides translation between Kiro and Claude formats.
|
||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
translator.Register(
|
||||||
|
Claude,
|
||||||
|
Kiro,
|
||||||
|
ConvertClaudeRequestToKiro,
|
||||||
|
interfaces.TranslateResponse{
|
||||||
|
Stream: ConvertKiroStreamToClaude,
|
||||||
|
NonStream: ConvertKiroNonStreamToClaude,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
21
internal/translator/kiro/claude/kiro_claude.go
Normal file
21
internal/translator/kiro/claude/kiro_claude.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
// Package claude provides translation between Kiro and Claude formats.
|
||||||
|
// Since Kiro executor generates Claude-compatible SSE format internally (with event: prefix),
|
||||||
|
// translations are pass-through for streaming, but responses need proper formatting.
|
||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertKiroStreamToClaude converts Kiro streaming response to Claude format.
|
||||||
|
// Kiro executor already generates complete SSE format with "event:" prefix,
|
||||||
|
// so this is a simple pass-through.
|
||||||
|
func ConvertKiroStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||||
|
return []string{string(rawResponse)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertKiroNonStreamToClaude converts Kiro non-streaming response to Claude format.
|
||||||
|
// The response is already in Claude format, so this is a pass-through.
|
||||||
|
func ConvertKiroNonStreamToClaude(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||||
|
return string(rawResponse)
|
||||||
|
}
|
||||||
774
internal/translator/kiro/claude/kiro_claude_request.go
Normal file
774
internal/translator/kiro/claude/kiro_claude_request.go
Normal file
@@ -0,0 +1,774 @@
|
|||||||
|
// Package claude provides request translation functionality for Claude API to Kiro format.
|
||||||
|
// It handles parsing and transforming Claude API requests into the Kiro/Amazon Q API format,
|
||||||
|
// extracting model information, system instructions, message contents, and tool declarations.
|
||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
// Kiro API request structs - field order determines JSON key order
|
||||||
|
|
||||||
|
// KiroPayload is the top-level request structure for Kiro API
|
||||||
|
type KiroPayload struct {
|
||||||
|
ConversationState KiroConversationState `json:"conversationState"`
|
||||||
|
ProfileArn string `json:"profileArn,omitempty"`
|
||||||
|
InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroInferenceConfig contains inference parameters for the Kiro API.
|
||||||
|
type KiroInferenceConfig struct {
|
||||||
|
MaxTokens int `json:"maxTokens,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"topP,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroConversationState holds the conversation context
|
||||||
|
type KiroConversationState struct {
|
||||||
|
ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL" - must be first field
|
||||||
|
ConversationID string `json:"conversationId"`
|
||||||
|
CurrentMessage KiroCurrentMessage `json:"currentMessage"`
|
||||||
|
History []KiroHistoryMessage `json:"history,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroCurrentMessage wraps the current user message
|
||||||
|
type KiroCurrentMessage struct {
|
||||||
|
UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroHistoryMessage represents a message in the conversation history
|
||||||
|
type KiroHistoryMessage struct {
|
||||||
|
UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
|
||||||
|
AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroImage represents an image in Kiro API format
|
||||||
|
type KiroImage struct {
|
||||||
|
Format string `json:"format"`
|
||||||
|
Source KiroImageSource `json:"source"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroImageSource contains the image data
|
||||||
|
type KiroImageSource struct {
|
||||||
|
Bytes string `json:"bytes"` // base64 encoded image data
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroUserInputMessage represents a user message
|
||||||
|
type KiroUserInputMessage struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
ModelID string `json:"modelId"`
|
||||||
|
Origin string `json:"origin"`
|
||||||
|
Images []KiroImage `json:"images,omitempty"`
|
||||||
|
UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroUserInputMessageContext contains tool-related context
|
||||||
|
type KiroUserInputMessageContext struct {
|
||||||
|
ToolResults []KiroToolResult `json:"toolResults,omitempty"`
|
||||||
|
Tools []KiroToolWrapper `json:"tools,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroToolResult represents a tool execution result
|
||||||
|
type KiroToolResult struct {
|
||||||
|
Content []KiroTextContent `json:"content"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
ToolUseID string `json:"toolUseId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroTextContent represents text content
|
||||||
|
type KiroTextContent struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroToolWrapper wraps a tool specification
|
||||||
|
type KiroToolWrapper struct {
|
||||||
|
ToolSpecification KiroToolSpecification `json:"toolSpecification"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroToolSpecification defines a tool's schema
|
||||||
|
type KiroToolSpecification struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
InputSchema KiroInputSchema `json:"inputSchema"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroInputSchema wraps the JSON schema for tool input
|
||||||
|
type KiroInputSchema struct {
|
||||||
|
JSON interface{} `json:"json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroAssistantResponseMessage represents an assistant message
|
||||||
|
type KiroAssistantResponseMessage struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
ToolUses []KiroToolUse `json:"toolUses,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroToolUse represents a tool invocation by the assistant
|
||||||
|
type KiroToolUse struct {
|
||||||
|
ToolUseID string `json:"toolUseId"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Input map[string]interface{} `json:"input"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertClaudeRequestToKiro converts a Claude API request to Kiro format.
|
||||||
|
// This is the main entry point for request translation.
|
||||||
|
func ConvertClaudeRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||||
|
// For Kiro, we pass through the Claude format since buildKiroPayload
|
||||||
|
// expects Claude format and does the conversion internally.
|
||||||
|
// The actual conversion happens in the executor when building the HTTP request.
|
||||||
|
return inputRawJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildKiroPayload constructs the Kiro API request payload from Claude format.
|
||||||
|
// Supports tool calling - tools are passed via userInputMessageContext.
|
||||||
|
// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
|
||||||
|
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
|
||||||
|
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
|
||||||
|
// Supports thinking mode - when Claude API thinking parameter is present, injects thinkingHint.
|
||||||
|
// Returns the payload and a boolean indicating whether thinking mode was injected.
|
||||||
|
func BuildKiroPayload(claudeBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) ([]byte, bool) {
|
||||||
|
// Extract max_tokens for potential use in inferenceConfig
|
||||||
|
// Handle -1 as "use maximum" (Kiro max output is ~32000 tokens)
|
||||||
|
const kiroMaxOutputTokens = 32000
|
||||||
|
var maxTokens int64
|
||||||
|
if mt := gjson.GetBytes(claudeBody, "max_tokens"); mt.Exists() {
|
||||||
|
maxTokens = mt.Int()
|
||||||
|
if maxTokens == -1 {
|
||||||
|
maxTokens = kiroMaxOutputTokens
|
||||||
|
log.Debugf("kiro: max_tokens=-1 converted to %d", kiroMaxOutputTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract temperature if specified
|
||||||
|
var temperature float64
|
||||||
|
var hasTemperature bool
|
||||||
|
if temp := gjson.GetBytes(claudeBody, "temperature"); temp.Exists() {
|
||||||
|
temperature = temp.Float()
|
||||||
|
hasTemperature = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract top_p if specified
|
||||||
|
var topP float64
|
||||||
|
var hasTopP bool
|
||||||
|
if tp := gjson.GetBytes(claudeBody, "top_p"); tp.Exists() {
|
||||||
|
topP = tp.Float()
|
||||||
|
hasTopP = true
|
||||||
|
log.Debugf("kiro: extracted top_p: %.2f", topP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize origin value for Kiro API compatibility
|
||||||
|
origin = normalizeOrigin(origin)
|
||||||
|
log.Debugf("kiro: normalized origin value: %s", origin)
|
||||||
|
|
||||||
|
messages := gjson.GetBytes(claudeBody, "messages")
|
||||||
|
|
||||||
|
// For chat-only mode, don't include tools
|
||||||
|
var tools gjson.Result
|
||||||
|
if !isChatOnly {
|
||||||
|
tools = gjson.GetBytes(claudeBody, "tools")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract system prompt
|
||||||
|
systemPrompt := extractSystemPrompt(claudeBody)
|
||||||
|
|
||||||
|
// Check for thinking mode using the comprehensive IsThinkingEnabled function
|
||||||
|
// This supports Claude API format, OpenAI reasoning_effort, and AMP/Cursor format
|
||||||
|
thinkingEnabled := IsThinkingEnabled(claudeBody)
|
||||||
|
_, budgetTokens := checkThinkingMode(claudeBody) // Get budget tokens from Claude format if available
|
||||||
|
if budgetTokens <= 0 {
|
||||||
|
// Calculate budgetTokens based on max_tokens if available
|
||||||
|
// Use 50% of max_tokens for thinking, with min 8000 and max 24000
|
||||||
|
if maxTokens > 0 {
|
||||||
|
budgetTokens = maxTokens / 2
|
||||||
|
if budgetTokens < 8000 {
|
||||||
|
budgetTokens = 8000
|
||||||
|
}
|
||||||
|
if budgetTokens > 24000 {
|
||||||
|
budgetTokens = 24000
|
||||||
|
}
|
||||||
|
log.Debugf("kiro: budgetTokens calculated from max_tokens: %d (max_tokens=%d)", budgetTokens, maxTokens)
|
||||||
|
} else {
|
||||||
|
budgetTokens = 16000 // Default budget tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject timestamp context
|
||||||
|
timestamp := time.Now().Format("2006-01-02 15:04:05 MST")
|
||||||
|
timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp)
|
||||||
|
if systemPrompt != "" {
|
||||||
|
systemPrompt = timestampContext + "\n\n" + systemPrompt
|
||||||
|
} else {
|
||||||
|
systemPrompt = timestampContext
|
||||||
|
}
|
||||||
|
log.Debugf("kiro: injected timestamp context: %s", timestamp)
|
||||||
|
|
||||||
|
// Inject agentic optimization prompt for -agentic model variants
|
||||||
|
if isAgentic {
|
||||||
|
if systemPrompt != "" {
|
||||||
|
systemPrompt += "\n"
|
||||||
|
}
|
||||||
|
systemPrompt += kirocommon.KiroAgenticSystemPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints
|
||||||
|
// Claude tool_choice values: {"type": "auto/any/tool", "name": "..."}
|
||||||
|
toolChoiceHint := extractClaudeToolChoiceHint(claudeBody)
|
||||||
|
if toolChoiceHint != "" {
|
||||||
|
if systemPrompt != "" {
|
||||||
|
systemPrompt += "\n"
|
||||||
|
}
|
||||||
|
systemPrompt += toolChoiceHint
|
||||||
|
log.Debugf("kiro: injected tool_choice hint into system prompt")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject thinking hint when thinking mode is enabled
|
||||||
|
if thinkingEnabled {
|
||||||
|
if systemPrompt != "" {
|
||||||
|
systemPrompt += "\n"
|
||||||
|
}
|
||||||
|
dynamicThinkingHint := fmt.Sprintf("<thinking_mode>interleaved</thinking_mode><max_thinking_length>%d</max_thinking_length>", budgetTokens)
|
||||||
|
systemPrompt += dynamicThinkingHint
|
||||||
|
log.Debugf("kiro: injected dynamic thinking hint into system prompt, max_thinking_length: %d", budgetTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert Claude tools to Kiro format
|
||||||
|
kiroTools := convertClaudeToolsToKiro(tools)
|
||||||
|
|
||||||
|
// Process messages and build history
|
||||||
|
history, currentUserMsg, currentToolResults := processMessages(messages, modelID, origin)
|
||||||
|
|
||||||
|
// Build content with system prompt
|
||||||
|
if currentUserMsg != nil {
|
||||||
|
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
|
||||||
|
|
||||||
|
// Deduplicate currentToolResults
|
||||||
|
currentToolResults = deduplicateToolResults(currentToolResults)
|
||||||
|
|
||||||
|
// Build userInputMessageContext with tools and tool results
|
||||||
|
if len(kiroTools) > 0 || len(currentToolResults) > 0 {
|
||||||
|
currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||||
|
Tools: kiroTools,
|
||||||
|
ToolResults: currentToolResults,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build payload
|
||||||
|
var currentMessage KiroCurrentMessage
|
||||||
|
if currentUserMsg != nil {
|
||||||
|
currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg}
|
||||||
|
} else {
|
||||||
|
fallbackContent := ""
|
||||||
|
if systemPrompt != "" {
|
||||||
|
fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n"
|
||||||
|
}
|
||||||
|
currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{
|
||||||
|
Content: fallbackContent,
|
||||||
|
ModelID: modelID,
|
||||||
|
Origin: origin,
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build inferenceConfig if we have any inference parameters
|
||||||
|
var inferenceConfig *KiroInferenceConfig
|
||||||
|
if maxTokens > 0 || hasTemperature || hasTopP {
|
||||||
|
inferenceConfig = &KiroInferenceConfig{}
|
||||||
|
if maxTokens > 0 {
|
||||||
|
inferenceConfig.MaxTokens = int(maxTokens)
|
||||||
|
}
|
||||||
|
if hasTemperature {
|
||||||
|
inferenceConfig.Temperature = temperature
|
||||||
|
}
|
||||||
|
if hasTopP {
|
||||||
|
inferenceConfig.TopP = topP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := KiroPayload{
|
||||||
|
ConversationState: KiroConversationState{
|
||||||
|
ChatTriggerType: "MANUAL",
|
||||||
|
ConversationID: uuid.New().String(),
|
||||||
|
CurrentMessage: currentMessage,
|
||||||
|
History: history,
|
||||||
|
},
|
||||||
|
ProfileArn: profileArn,
|
||||||
|
InferenceConfig: inferenceConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("kiro: failed to marshal payload: %v", err)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, thinkingEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeOrigin normalizes origin value for Kiro API compatibility
|
||||||
|
func normalizeOrigin(origin string) string {
|
||||||
|
switch origin {
|
||||||
|
case "KIRO_CLI":
|
||||||
|
return "CLI"
|
||||||
|
case "KIRO_AI_EDITOR":
|
||||||
|
return "AI_EDITOR"
|
||||||
|
case "AMAZON_Q":
|
||||||
|
return "CLI"
|
||||||
|
case "KIRO_IDE":
|
||||||
|
return "AI_EDITOR"
|
||||||
|
default:
|
||||||
|
return origin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractSystemPrompt extracts system prompt from Claude request
|
||||||
|
func extractSystemPrompt(claudeBody []byte) string {
|
||||||
|
systemField := gjson.GetBytes(claudeBody, "system")
|
||||||
|
if systemField.IsArray() {
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, block := range systemField.Array() {
|
||||||
|
if block.Get("type").String() == "text" {
|
||||||
|
sb.WriteString(block.Get("text").String())
|
||||||
|
} else if block.Type == gjson.String {
|
||||||
|
sb.WriteString(block.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
return systemField.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkThinkingMode checks if thinking mode is enabled in the Claude request
|
||||||
|
func checkThinkingMode(claudeBody []byte) (bool, int64) {
|
||||||
|
thinkingEnabled := false
|
||||||
|
var budgetTokens int64 = 16000
|
||||||
|
|
||||||
|
thinkingField := gjson.GetBytes(claudeBody, "thinking")
|
||||||
|
if thinkingField.Exists() {
|
||||||
|
thinkingType := thinkingField.Get("type").String()
|
||||||
|
if thinkingType == "enabled" {
|
||||||
|
thinkingEnabled = true
|
||||||
|
if bt := thinkingField.Get("budget_tokens"); bt.Exists() {
|
||||||
|
budgetTokens = bt.Int()
|
||||||
|
if budgetTokens <= 0 {
|
||||||
|
thinkingEnabled = false
|
||||||
|
log.Debugf("kiro: thinking mode disabled via budget_tokens <= 0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if thinkingEnabled {
|
||||||
|
log.Debugf("kiro: thinking mode enabled via Claude API parameter, budget_tokens: %d", budgetTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return thinkingEnabled, budgetTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsThinkingEnabled is a public wrapper to check if thinking mode is enabled.
|
||||||
|
// This is used by the executor to determine whether to parse <thinking> tags in responses.
|
||||||
|
// When thinking is NOT enabled in the request, <thinking> tags in responses should be
|
||||||
|
// treated as regular text content, not as thinking blocks.
|
||||||
|
//
|
||||||
|
// Supports multiple formats:
|
||||||
|
// - Claude API format: thinking.type = "enabled"
|
||||||
|
// - OpenAI format: reasoning_effort parameter
|
||||||
|
// - AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||||
|
func IsThinkingEnabled(body []byte) bool {
|
||||||
|
// Check Claude API format first (thinking.type = "enabled")
|
||||||
|
enabled, _ := checkThinkingMode(body)
|
||||||
|
if enabled {
|
||||||
|
log.Debugf("kiro: IsThinkingEnabled returning true (Claude API format)")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check OpenAI format: reasoning_effort parameter
|
||||||
|
// Valid values: "low", "medium", "high", "auto" (not "none")
|
||||||
|
reasoningEffort := gjson.GetBytes(body, "reasoning_effort")
|
||||||
|
if reasoningEffort.Exists() {
|
||||||
|
effort := reasoningEffort.String()
|
||||||
|
if effort != "" && effort != "none" {
|
||||||
|
log.Debugf("kiro: thinking mode enabled via OpenAI reasoning_effort: %s", effort)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||||
|
// This is how AMP client passes thinking configuration
|
||||||
|
bodyStr := string(body)
|
||||||
|
if strings.Contains(bodyStr, "<thinking_mode>") && strings.Contains(bodyStr, "</thinking_mode>") {
|
||||||
|
// Extract thinking mode value
|
||||||
|
startTag := "<thinking_mode>"
|
||||||
|
endTag := "</thinking_mode>"
|
||||||
|
startIdx := strings.Index(bodyStr, startTag)
|
||||||
|
if startIdx >= 0 {
|
||||||
|
startIdx += len(startTag)
|
||||||
|
endIdx := strings.Index(bodyStr[startIdx:], endTag)
|
||||||
|
if endIdx >= 0 {
|
||||||
|
thinkingMode := bodyStr[startIdx : startIdx+endIdx]
|
||||||
|
if thinkingMode == "interleaved" || thinkingMode == "enabled" {
|
||||||
|
log.Debugf("kiro: thinking mode enabled via AMP/Cursor format: %s", thinkingMode)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check OpenAI format: max_completion_tokens with reasoning (o1-style)
|
||||||
|
// Some clients use this to indicate reasoning mode
|
||||||
|
if gjson.GetBytes(body, "max_completion_tokens").Exists() {
|
||||||
|
// If max_completion_tokens is set, check if model name suggests reasoning
|
||||||
|
model := gjson.GetBytes(body, "model").String()
|
||||||
|
if strings.Contains(strings.ToLower(model), "thinking") ||
|
||||||
|
strings.Contains(strings.ToLower(model), "reason") {
|
||||||
|
log.Debugf("kiro: thinking mode enabled via model name hint: %s", model)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("kiro: IsThinkingEnabled returning false (no thinking mode detected)")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// shortenToolNameIfNeeded shortens tool names that exceed 64 characters.
|
||||||
|
// MCP tools often have long names like "mcp__server-name__tool-name".
|
||||||
|
// This preserves the "mcp__" prefix and last segment when possible.
|
||||||
|
func shortenToolNameIfNeeded(name string) string {
|
||||||
|
const limit = 64
|
||||||
|
if len(name) <= limit {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
// For MCP tools, try to preserve prefix and last segment
|
||||||
|
if strings.HasPrefix(name, "mcp__") {
|
||||||
|
idx := strings.LastIndex(name, "__")
|
||||||
|
if idx > 0 {
|
||||||
|
cand := "mcp__" + name[idx+2:]
|
||||||
|
if len(cand) > limit {
|
||||||
|
return cand[:limit]
|
||||||
|
}
|
||||||
|
return cand
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return name[:limit]
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertClaudeToolsToKiro converts Claude tools to Kiro format
|
||||||
|
func convertClaudeToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||||
|
var kiroTools []KiroToolWrapper
|
||||||
|
if !tools.IsArray() {
|
||||||
|
return kiroTools
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tool := range tools.Array() {
|
||||||
|
name := tool.Get("name").String()
|
||||||
|
description := tool.Get("description").String()
|
||||||
|
inputSchema := tool.Get("input_schema").Value()
|
||||||
|
|
||||||
|
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
||||||
|
originalName := name
|
||||||
|
name = shortenToolNameIfNeeded(name)
|
||||||
|
if name != originalName {
|
||||||
|
log.Debugf("kiro: shortened tool name from '%s' to '%s'", originalName, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CRITICAL FIX: Kiro API requires non-empty description
|
||||||
|
if strings.TrimSpace(description) == "" {
|
||||||
|
description = fmt.Sprintf("Tool: %s", name)
|
||||||
|
log.Debugf("kiro: tool '%s' has empty description, using default: %s", name, description)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truncate long descriptions
|
||||||
|
if len(description) > kirocommon.KiroMaxToolDescLen {
|
||||||
|
truncLen := kirocommon.KiroMaxToolDescLen - 30
|
||||||
|
for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
|
||||||
|
truncLen--
|
||||||
|
}
|
||||||
|
description = description[:truncLen] + "... (description truncated)"
|
||||||
|
}
|
||||||
|
|
||||||
|
kiroTools = append(kiroTools, KiroToolWrapper{
|
||||||
|
ToolSpecification: KiroToolSpecification{
|
||||||
|
Name: name,
|
||||||
|
Description: description,
|
||||||
|
InputSchema: KiroInputSchema{JSON: inputSchema},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return kiroTools
|
||||||
|
}
|
||||||
|
|
||||||
|
// processMessages processes Claude messages and builds Kiro history
|
||||||
|
func processMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) {
|
||||||
|
var history []KiroHistoryMessage
|
||||||
|
var currentUserMsg *KiroUserInputMessage
|
||||||
|
var currentToolResults []KiroToolResult
|
||||||
|
|
||||||
|
// Merge adjacent messages with the same role
|
||||||
|
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
|
||||||
|
for i, msg := range messagesArray {
|
||||||
|
role := msg.Get("role").String()
|
||||||
|
isLastMessage := i == len(messagesArray)-1
|
||||||
|
|
||||||
|
if role == "user" {
|
||||||
|
userMsg, toolResults := BuildUserMessageStruct(msg, modelID, origin)
|
||||||
|
if isLastMessage {
|
||||||
|
currentUserMsg = &userMsg
|
||||||
|
currentToolResults = toolResults
|
||||||
|
} else {
|
||||||
|
// CRITICAL: Kiro API requires content to be non-empty for history messages too
|
||||||
|
if strings.TrimSpace(userMsg.Content) == "" {
|
||||||
|
if len(toolResults) > 0 {
|
||||||
|
userMsg.Content = "Tool results provided."
|
||||||
|
} else {
|
||||||
|
userMsg.Content = "Continue"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// For history messages, embed tool results in context
|
||||||
|
if len(toolResults) > 0 {
|
||||||
|
userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||||
|
ToolResults: toolResults,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
history = append(history, KiroHistoryMessage{
|
||||||
|
UserInputMessage: &userMsg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else if role == "assistant" {
|
||||||
|
assistantMsg := BuildAssistantMessageStruct(msg)
|
||||||
|
if isLastMessage {
|
||||||
|
history = append(history, KiroHistoryMessage{
|
||||||
|
AssistantResponseMessage: &assistantMsg,
|
||||||
|
})
|
||||||
|
// Create a "Continue" user message as currentMessage
|
||||||
|
currentUserMsg = &KiroUserInputMessage{
|
||||||
|
Content: "Continue",
|
||||||
|
ModelID: modelID,
|
||||||
|
Origin: origin,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
history = append(history, KiroHistoryMessage{
|
||||||
|
AssistantResponseMessage: &assistantMsg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return history, currentUserMsg, currentToolResults
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildFinalContent builds the final content with system prompt
|
||||||
|
func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string {
|
||||||
|
var contentBuilder strings.Builder
|
||||||
|
|
||||||
|
if systemPrompt != "" {
|
||||||
|
contentBuilder.WriteString("--- SYSTEM PROMPT ---\n")
|
||||||
|
contentBuilder.WriteString(systemPrompt)
|
||||||
|
contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
contentBuilder.WriteString(content)
|
||||||
|
finalContent := contentBuilder.String()
|
||||||
|
|
||||||
|
// CRITICAL: Kiro API requires content to be non-empty
|
||||||
|
if strings.TrimSpace(finalContent) == "" {
|
||||||
|
if len(toolResults) > 0 {
|
||||||
|
finalContent = "Tool results provided."
|
||||||
|
} else {
|
||||||
|
finalContent = "Continue"
|
||||||
|
}
|
||||||
|
log.Debugf("kiro: content was empty, using default: %s", finalContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
return finalContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// deduplicateToolResults removes duplicate tool results
|
||||||
|
func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult {
|
||||||
|
if len(toolResults) == 0 {
|
||||||
|
return toolResults
|
||||||
|
}
|
||||||
|
|
||||||
|
seenIDs := make(map[string]bool)
|
||||||
|
unique := make([]KiroToolResult, 0, len(toolResults))
|
||||||
|
for _, tr := range toolResults {
|
||||||
|
if !seenIDs[tr.ToolUseID] {
|
||||||
|
seenIDs[tr.ToolUseID] = true
|
||||||
|
unique = append(unique, tr)
|
||||||
|
} else {
|
||||||
|
log.Debugf("kiro: skipping duplicate toolResult in currentMessage: %s", tr.ToolUseID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return unique
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractClaudeToolChoiceHint extracts tool_choice from Claude request and returns a system prompt hint.
|
||||||
|
// Claude tool_choice values:
|
||||||
|
// - {"type": "auto"}: Model decides (default, no hint needed)
|
||||||
|
// - {"type": "any"}: Must use at least one tool
|
||||||
|
// - {"type": "tool", "name": "..."}: Must use specific tool
|
||||||
|
func extractClaudeToolChoiceHint(claudeBody []byte) string {
|
||||||
|
toolChoice := gjson.GetBytes(claudeBody, "tool_choice")
|
||||||
|
if !toolChoice.Exists() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
toolChoiceType := toolChoice.Get("type").String()
|
||||||
|
switch toolChoiceType {
|
||||||
|
case "any":
|
||||||
|
return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]"
|
||||||
|
case "tool":
|
||||||
|
toolName := toolChoice.Get("name").String()
|
||||||
|
if toolName != "" {
|
||||||
|
return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName)
|
||||||
|
}
|
||||||
|
case "auto":
|
||||||
|
// Default behavior, no hint needed
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildUserMessageStruct builds a user message and extracts tool results
|
||||||
|
func BuildUserMessageStruct(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
|
||||||
|
content := msg.Get("content")
|
||||||
|
var contentBuilder strings.Builder
|
||||||
|
var toolResults []KiroToolResult
|
||||||
|
var images []KiroImage
|
||||||
|
|
||||||
|
// Track seen toolUseIds to deduplicate
|
||||||
|
seenToolUseIDs := make(map[string]bool)
|
||||||
|
|
||||||
|
if content.IsArray() {
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
partType := part.Get("type").String()
|
||||||
|
switch partType {
|
||||||
|
case "text":
|
||||||
|
contentBuilder.WriteString(part.Get("text").String())
|
||||||
|
case "image":
|
||||||
|
mediaType := part.Get("source.media_type").String()
|
||||||
|
data := part.Get("source.data").String()
|
||||||
|
|
||||||
|
format := ""
|
||||||
|
if idx := strings.LastIndex(mediaType, "/"); idx != -1 {
|
||||||
|
format = mediaType[idx+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if format != "" && data != "" {
|
||||||
|
images = append(images, KiroImage{
|
||||||
|
Format: format,
|
||||||
|
Source: KiroImageSource{
|
||||||
|
Bytes: data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "tool_result":
|
||||||
|
toolUseID := part.Get("tool_use_id").String()
|
||||||
|
|
||||||
|
// Skip duplicate toolUseIds
|
||||||
|
if seenToolUseIDs[toolUseID] {
|
||||||
|
log.Debugf("kiro: skipping duplicate tool_result with toolUseId: %s", toolUseID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenToolUseIDs[toolUseID] = true
|
||||||
|
|
||||||
|
isError := part.Get("is_error").Bool()
|
||||||
|
resultContent := part.Get("content")
|
||||||
|
|
||||||
|
var textContents []KiroTextContent
|
||||||
|
if resultContent.IsArray() {
|
||||||
|
for _, item := range resultContent.Array() {
|
||||||
|
if item.Get("type").String() == "text" {
|
||||||
|
textContents = append(textContents, KiroTextContent{Text: item.Get("text").String()})
|
||||||
|
} else if item.Type == gjson.String {
|
||||||
|
textContents = append(textContents, KiroTextContent{Text: item.String()})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if resultContent.Type == gjson.String {
|
||||||
|
textContents = append(textContents, KiroTextContent{Text: resultContent.String()})
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(textContents) == 0 {
|
||||||
|
textContents = append(textContents, KiroTextContent{Text: "Tool use was cancelled by the user"})
|
||||||
|
}
|
||||||
|
|
||||||
|
status := "success"
|
||||||
|
if isError {
|
||||||
|
status = "error"
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResults = append(toolResults, KiroToolResult{
|
||||||
|
ToolUseID: toolUseID,
|
||||||
|
Content: textContents,
|
||||||
|
Status: status,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
contentBuilder.WriteString(content.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
userMsg := KiroUserInputMessage{
|
||||||
|
Content: contentBuilder.String(),
|
||||||
|
ModelID: modelID,
|
||||||
|
Origin: origin,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(images) > 0 {
|
||||||
|
userMsg.Images = images
|
||||||
|
}
|
||||||
|
|
||||||
|
return userMsg, toolResults
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildAssistantMessageStruct builds an assistant message with tool uses
|
||||||
|
func BuildAssistantMessageStruct(msg gjson.Result) KiroAssistantResponseMessage {
|
||||||
|
content := msg.Get("content")
|
||||||
|
var contentBuilder strings.Builder
|
||||||
|
var toolUses []KiroToolUse
|
||||||
|
|
||||||
|
if content.IsArray() {
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
partType := part.Get("type").String()
|
||||||
|
switch partType {
|
||||||
|
case "text":
|
||||||
|
contentBuilder.WriteString(part.Get("text").String())
|
||||||
|
case "tool_use":
|
||||||
|
toolUseID := part.Get("id").String()
|
||||||
|
toolName := part.Get("name").String()
|
||||||
|
toolInput := part.Get("input")
|
||||||
|
|
||||||
|
var inputMap map[string]interface{}
|
||||||
|
if toolInput.IsObject() {
|
||||||
|
inputMap = make(map[string]interface{})
|
||||||
|
toolInput.ForEach(func(key, value gjson.Result) bool {
|
||||||
|
inputMap[key.String()] = value.Value()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
toolUses = append(toolUses, KiroToolUse{
|
||||||
|
ToolUseID: toolUseID,
|
||||||
|
Name: toolName,
|
||||||
|
Input: inputMap,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
contentBuilder.WriteString(content.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return KiroAssistantResponseMessage{
|
||||||
|
Content: contentBuilder.String(),
|
||||||
|
ToolUses: toolUses,
|
||||||
|
}
|
||||||
|
}
|
||||||
184
internal/translator/kiro/claude/kiro_claude_response.go
Normal file
184
internal/translator/kiro/claude/kiro_claude_response.go
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
// Package claude provides response translation functionality for Kiro API to Claude format.
|
||||||
|
// This package handles the conversion of Kiro API responses into Claude-compatible format,
|
||||||
|
// including support for thinking blocks and tool use.
|
||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Local references to kirocommon constants for thinking block parsing
|
||||||
|
var (
|
||||||
|
thinkingStartTag = kirocommon.ThinkingStartTag
|
||||||
|
thinkingEndTag = kirocommon.ThinkingEndTag
|
||||||
|
)
|
||||||
|
|
||||||
|
// BuildClaudeResponse constructs a Claude-compatible response.
|
||||||
|
// Supports tool_use blocks when tools are present in the response.
|
||||||
|
// Supports thinking blocks - parses <thinking> tags and converts to Claude thinking content blocks.
|
||||||
|
// stopReason is passed from upstream; fallback logic applied if empty.
|
||||||
|
func BuildClaudeResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
|
||||||
|
var contentBlocks []map[string]interface{}
|
||||||
|
|
||||||
|
// Extract thinking blocks and text from content
|
||||||
|
if content != "" {
|
||||||
|
blocks := ExtractThinkingFromContent(content)
|
||||||
|
contentBlocks = append(contentBlocks, blocks...)
|
||||||
|
|
||||||
|
// Log if thinking blocks were extracted
|
||||||
|
for _, block := range blocks {
|
||||||
|
if block["type"] == "thinking" {
|
||||||
|
thinkingContent := block["thinking"].(string)
|
||||||
|
log.Infof("kiro: buildClaudeResponse extracted thinking block (len: %d)", len(thinkingContent))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tool_use blocks
|
||||||
|
for _, toolUse := range toolUses {
|
||||||
|
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": toolUse.ToolUseID,
|
||||||
|
"name": toolUse.Name,
|
||||||
|
"input": toolUse.Input,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure at least one content block (Claude API requires non-empty content)
|
||||||
|
if len(contentBlocks) == 0 {
|
||||||
|
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use upstream stopReason; apply fallback logic if not provided
|
||||||
|
if stopReason == "" {
|
||||||
|
stopReason = "end_turn"
|
||||||
|
if len(toolUses) > 0 {
|
||||||
|
stopReason = "tool_use"
|
||||||
|
}
|
||||||
|
log.Debugf("kiro: buildClaudeResponse using fallback stop_reason: %s", stopReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log warning if response was truncated due to max_tokens
|
||||||
|
if stopReason == "max_tokens" {
|
||||||
|
log.Warnf("kiro: response truncated due to max_tokens limit (buildClaudeResponse)")
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"id": "msg_" + uuid.New().String()[:24],
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": model,
|
||||||
|
"content": contentBlocks,
|
||||||
|
"stop_reason": stopReason,
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"input_tokens": usageInfo.InputTokens,
|
||||||
|
"output_tokens": usageInfo.OutputTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(response)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractThinkingFromContent parses content to extract thinking blocks and text.
|
||||||
|
// Returns a list of content blocks in the order they appear in the content.
|
||||||
|
// Handles interleaved thinking and text blocks correctly.
|
||||||
|
func ExtractThinkingFromContent(content string) []map[string]interface{} {
|
||||||
|
var blocks []map[string]interface{}
|
||||||
|
|
||||||
|
if content == "" {
|
||||||
|
return blocks
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if content contains thinking tags at all
|
||||||
|
if !strings.Contains(content, thinkingStartTag) {
|
||||||
|
// No thinking tags, return as plain text
|
||||||
|
return []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": content,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("kiro: extractThinkingFromContent - found thinking tags in content (len: %d)", len(content))
|
||||||
|
|
||||||
|
remaining := content
|
||||||
|
|
||||||
|
for len(remaining) > 0 {
|
||||||
|
// Look for <thinking> tag
|
||||||
|
startIdx := strings.Index(remaining, thinkingStartTag)
|
||||||
|
|
||||||
|
if startIdx == -1 {
|
||||||
|
// No more thinking tags, add remaining as text
|
||||||
|
if strings.TrimSpace(remaining) != "" {
|
||||||
|
blocks = append(blocks, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": remaining,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add text before thinking tag (if any meaningful content)
|
||||||
|
if startIdx > 0 {
|
||||||
|
textBefore := remaining[:startIdx]
|
||||||
|
if strings.TrimSpace(textBefore) != "" {
|
||||||
|
blocks = append(blocks, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": textBefore,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move past the opening tag
|
||||||
|
remaining = remaining[startIdx+len(thinkingStartTag):]
|
||||||
|
|
||||||
|
// Find closing tag
|
||||||
|
endIdx := strings.Index(remaining, thinkingEndTag)
|
||||||
|
|
||||||
|
if endIdx == -1 {
|
||||||
|
// No closing tag found, treat rest as thinking content (incomplete response)
|
||||||
|
if strings.TrimSpace(remaining) != "" {
|
||||||
|
blocks = append(blocks, map[string]interface{}{
|
||||||
|
"type": "thinking",
|
||||||
|
"thinking": remaining,
|
||||||
|
})
|
||||||
|
log.Warnf("kiro: extractThinkingFromContent - missing closing </thinking> tag")
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract thinking content between tags
|
||||||
|
thinkContent := remaining[:endIdx]
|
||||||
|
if strings.TrimSpace(thinkContent) != "" {
|
||||||
|
blocks = append(blocks, map[string]interface{}{
|
||||||
|
"type": "thinking",
|
||||||
|
"thinking": thinkContent,
|
||||||
|
})
|
||||||
|
log.Debugf("kiro: extractThinkingFromContent - extracted thinking block (len: %d)", len(thinkContent))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move past the closing tag
|
||||||
|
remaining = remaining[endIdx+len(thinkingEndTag):]
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no blocks were created (all whitespace), return empty text block
|
||||||
|
if len(blocks) == 0 {
|
||||||
|
blocks = append(blocks, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return blocks
|
||||||
|
}
|
||||||
176
internal/translator/kiro/claude/kiro_claude_stream.go
Normal file
176
internal/translator/kiro/claude/kiro_claude_stream.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
// Package claude provides streaming SSE event building for Claude format.
|
||||||
|
// This package handles the construction of Claude-compatible Server-Sent Events (SSE)
|
||||||
|
// for streaming responses from Kiro API.
|
||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BuildClaudeMessageStartEvent creates the message_start SSE event
|
||||||
|
func BuildClaudeMessageStartEvent(model string, inputTokens int64) []byte {
|
||||||
|
event := map[string]interface{}{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": map[string]interface{}{
|
||||||
|
"id": "msg_" + uuid.New().String()[:24],
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": []interface{}{},
|
||||||
|
"model": model,
|
||||||
|
"stop_reason": nil,
|
||||||
|
"stop_sequence": nil,
|
||||||
|
"usage": map[string]interface{}{"input_tokens": inputTokens, "output_tokens": 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(event)
|
||||||
|
return []byte("event: message_start\ndata: " + string(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClaudeContentBlockStartEvent creates a content_block_start SSE event
|
||||||
|
func BuildClaudeContentBlockStartEvent(index int, blockType, toolUseID, toolName string) []byte {
|
||||||
|
var contentBlock map[string]interface{}
|
||||||
|
switch blockType {
|
||||||
|
case "tool_use":
|
||||||
|
contentBlock = map[string]interface{}{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": toolUseID,
|
||||||
|
"name": toolName,
|
||||||
|
"input": map[string]interface{}{},
|
||||||
|
}
|
||||||
|
case "thinking":
|
||||||
|
contentBlock = map[string]interface{}{
|
||||||
|
"type": "thinking",
|
||||||
|
"thinking": "",
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
contentBlock = map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": "",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
event := map[string]interface{}{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": index,
|
||||||
|
"content_block": contentBlock,
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(event)
|
||||||
|
return []byte("event: content_block_start\ndata: " + string(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClaudeStreamEvent creates a text_delta content_block_delta SSE event
|
||||||
|
func BuildClaudeStreamEvent(contentDelta string, index int) []byte {
|
||||||
|
event := map[string]interface{}{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": index,
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"type": "text_delta",
|
||||||
|
"text": contentDelta,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(event)
|
||||||
|
return []byte("event: content_block_delta\ndata: " + string(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClaudeInputJsonDeltaEvent creates an input_json_delta event for tool use streaming
|
||||||
|
func BuildClaudeInputJsonDeltaEvent(partialJSON string, index int) []byte {
|
||||||
|
event := map[string]interface{}{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": index,
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"type": "input_json_delta",
|
||||||
|
"partial_json": partialJSON,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(event)
|
||||||
|
return []byte("event: content_block_delta\ndata: " + string(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClaudeContentBlockStopEvent creates a content_block_stop SSE event
|
||||||
|
func BuildClaudeContentBlockStopEvent(index int) []byte {
|
||||||
|
event := map[string]interface{}{
|
||||||
|
"type": "content_block_stop",
|
||||||
|
"index": index,
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(event)
|
||||||
|
return []byte("event: content_block_stop\ndata: " + string(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClaudeMessageDeltaEvent creates the message_delta event with stop_reason and usage
|
||||||
|
func BuildClaudeMessageDeltaEvent(stopReason string, usageInfo usage.Detail) []byte {
|
||||||
|
deltaEvent := map[string]interface{}{
|
||||||
|
"type": "message_delta",
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"stop_reason": stopReason,
|
||||||
|
"stop_sequence": nil,
|
||||||
|
},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"input_tokens": usageInfo.InputTokens,
|
||||||
|
"output_tokens": usageInfo.OutputTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
deltaResult, _ := json.Marshal(deltaEvent)
|
||||||
|
return []byte("event: message_delta\ndata: " + string(deltaResult))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClaudeMessageStopOnlyEvent creates only the message_stop event
|
||||||
|
func BuildClaudeMessageStopOnlyEvent() []byte {
|
||||||
|
stopEvent := map[string]interface{}{
|
||||||
|
"type": "message_stop",
|
||||||
|
}
|
||||||
|
stopResult, _ := json.Marshal(stopEvent)
|
||||||
|
return []byte("event: message_stop\ndata: " + string(stopResult))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClaudePingEventWithUsage creates a ping event with embedded usage information.
|
||||||
|
// This is used for real-time usage estimation during streaming.
|
||||||
|
func BuildClaudePingEventWithUsage(inputTokens, outputTokens int64) []byte {
|
||||||
|
event := map[string]interface{}{
|
||||||
|
"type": "ping",
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"input_tokens": inputTokens,
|
||||||
|
"output_tokens": outputTokens,
|
||||||
|
"total_tokens": inputTokens + outputTokens,
|
||||||
|
"estimated": true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(event)
|
||||||
|
return []byte("event: ping\ndata: " + string(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClaudeThinkingDeltaEvent creates a thinking_delta event for Claude API compatibility.
|
||||||
|
// This is used when streaming thinking content wrapped in <thinking> tags.
|
||||||
|
func BuildClaudeThinkingDeltaEvent(thinkingDelta string, index int) []byte {
|
||||||
|
event := map[string]interface{}{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": index,
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"type": "thinking_delta",
|
||||||
|
"thinking": thinkingDelta,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(event)
|
||||||
|
return []byte("event: content_block_delta\ndata: " + string(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PendingTagSuffix detects if the buffer ends with a partial prefix of the given tag.
|
||||||
|
// Returns the length of the partial match (0 if no match).
|
||||||
|
// Based on amq2api implementation for handling cross-chunk tag boundaries.
|
||||||
|
func PendingTagSuffix(buffer, tag string) int {
|
||||||
|
if buffer == "" || tag == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
maxLen := len(buffer)
|
||||||
|
if maxLen > len(tag)-1 {
|
||||||
|
maxLen = len(tag) - 1
|
||||||
|
}
|
||||||
|
for length := maxLen; length > 0; length-- {
|
||||||
|
if len(buffer) >= length && buffer[len(buffer)-length:] == tag[:length] {
|
||||||
|
return length
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
522
internal/translator/kiro/claude/kiro_claude_tools.go
Normal file
522
internal/translator/kiro/claude/kiro_claude_tools.go
Normal file
@@ -0,0 +1,522 @@
|
|||||||
|
// Package claude provides tool calling support for Kiro to Claude translation.
|
||||||
|
// This package handles parsing embedded tool calls, JSON repair, and deduplication.
|
||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToolUseState tracks the state of an in-progress tool use during streaming.
|
||||||
|
type ToolUseState struct {
|
||||||
|
ToolUseID string
|
||||||
|
Name string
|
||||||
|
InputBuffer strings.Builder
|
||||||
|
IsComplete bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-compiled regex patterns for performance
|
||||||
|
var (
|
||||||
|
// embeddedToolCallPattern matches [Called tool_name with args: {...}] format
|
||||||
|
embeddedToolCallPattern = regexp.MustCompile(`\[Called\s+([A-Za-z0-9_.-]+)\s+with\s+args:\s*`)
|
||||||
|
// trailingCommaPattern matches trailing commas before closing braces/brackets
|
||||||
|
trailingCommaPattern = regexp.MustCompile(`,\s*([}\]])`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseEmbeddedToolCalls extracts [Called tool_name with args: {...}] format from text.
|
||||||
|
// Kiro sometimes embeds tool calls in text content instead of using toolUseEvent.
|
||||||
|
// Returns the cleaned text (with tool calls removed) and extracted tool uses.
|
||||||
|
func ParseEmbeddedToolCalls(text string, processedIDs map[string]bool) (string, []KiroToolUse) {
|
||||||
|
if !strings.Contains(text, "[Called") {
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolUses []KiroToolUse
|
||||||
|
cleanText := text
|
||||||
|
|
||||||
|
// Find all [Called markers
|
||||||
|
matches := embeddedToolCallPattern.FindAllStringSubmatchIndex(text, -1)
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return text, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process matches in reverse order to maintain correct indices
|
||||||
|
for i := len(matches) - 1; i >= 0; i-- {
|
||||||
|
matchStart := matches[i][0]
|
||||||
|
toolNameStart := matches[i][2]
|
||||||
|
toolNameEnd := matches[i][3]
|
||||||
|
|
||||||
|
if toolNameStart < 0 || toolNameEnd < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
toolName := text[toolNameStart:toolNameEnd]
|
||||||
|
|
||||||
|
// Find the JSON object start (after "with args:")
|
||||||
|
jsonStart := matches[i][1]
|
||||||
|
if jsonStart >= len(text) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip whitespace to find the opening brace
|
||||||
|
for jsonStart < len(text) && (text[jsonStart] == ' ' || text[jsonStart] == '\t') {
|
||||||
|
jsonStart++
|
||||||
|
}
|
||||||
|
|
||||||
|
if jsonStart >= len(text) || text[jsonStart] != '{' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find matching closing bracket
|
||||||
|
jsonEnd := findMatchingBracket(text, jsonStart)
|
||||||
|
if jsonEnd < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract JSON and find the closing bracket of [Called ...]
|
||||||
|
jsonStr := text[jsonStart : jsonEnd+1]
|
||||||
|
|
||||||
|
// Find the closing ] after the JSON
|
||||||
|
closingBracket := jsonEnd + 1
|
||||||
|
for closingBracket < len(text) && text[closingBracket] != ']' {
|
||||||
|
closingBracket++
|
||||||
|
}
|
||||||
|
if closingBracket >= len(text) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// End index of the full tool call (closing ']' inclusive)
|
||||||
|
matchEnd := closingBracket + 1
|
||||||
|
|
||||||
|
// Repair and parse JSON
|
||||||
|
repairedJSON := RepairJSON(jsonStr)
|
||||||
|
var inputMap map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(repairedJSON), &inputMap); err != nil {
|
||||||
|
log.Debugf("kiro: failed to parse embedded tool call JSON: %v, raw: %s", err, jsonStr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate unique tool ID
|
||||||
|
toolUseID := "toolu_" + uuid.New().String()[:12]
|
||||||
|
|
||||||
|
// Check for duplicates using name+input as key
|
||||||
|
dedupeKey := toolName + ":" + repairedJSON
|
||||||
|
if processedIDs != nil {
|
||||||
|
if processedIDs[dedupeKey] {
|
||||||
|
log.Debugf("kiro: skipping duplicate embedded tool call: %s", toolName)
|
||||||
|
// Still remove from text even if duplicate
|
||||||
|
if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd {
|
||||||
|
cleanText = cleanText[:matchStart] + cleanText[matchEnd:]
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
processedIDs[dedupeKey] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
toolUses = append(toolUses, KiroToolUse{
|
||||||
|
ToolUseID: toolUseID,
|
||||||
|
Name: toolName,
|
||||||
|
Input: inputMap,
|
||||||
|
})
|
||||||
|
|
||||||
|
log.Infof("kiro: extracted embedded tool call: %s (ID: %s)", toolName, toolUseID)
|
||||||
|
|
||||||
|
// Remove from clean text (index-based removal to avoid deleting the wrong occurrence)
|
||||||
|
if matchStart >= 0 && matchEnd <= len(cleanText) && matchStart <= matchEnd {
|
||||||
|
cleanText = cleanText[:matchStart] + cleanText[matchEnd:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cleanText, toolUses
|
||||||
|
}
|
||||||
|
|
||||||
|
// findMatchingBracket finds the index of the closing brace/bracket that matches
|
||||||
|
// the opening one at startPos. Handles nested objects and strings correctly.
|
||||||
|
func findMatchingBracket(text string, startPos int) int {
|
||||||
|
if startPos >= len(text) {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
openChar := text[startPos]
|
||||||
|
var closeChar byte
|
||||||
|
switch openChar {
|
||||||
|
case '{':
|
||||||
|
closeChar = '}'
|
||||||
|
case '[':
|
||||||
|
closeChar = ']'
|
||||||
|
default:
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
depth := 1
|
||||||
|
inString := false
|
||||||
|
escapeNext := false
|
||||||
|
|
||||||
|
for i := startPos + 1; i < len(text); i++ {
|
||||||
|
char := text[i]
|
||||||
|
|
||||||
|
if escapeNext {
|
||||||
|
escapeNext = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if char == '\\' && inString {
|
||||||
|
escapeNext = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if char == '"' {
|
||||||
|
inString = !inString
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !inString {
|
||||||
|
if char == openChar {
|
||||||
|
depth++
|
||||||
|
} else if char == closeChar {
|
||||||
|
depth--
|
||||||
|
if depth == 0 {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// RepairJSON attempts to fix common JSON issues that may occur in tool call arguments.
|
||||||
|
// Conservative repair strategy:
|
||||||
|
// 1. First try to parse JSON directly - if valid, return as-is
|
||||||
|
// 2. Only attempt repair if parsing fails
|
||||||
|
// 3. After repair, validate the result - if still invalid, return original
|
||||||
|
func RepairJSON(jsonString string) string {
|
||||||
|
// Handle empty or invalid input
|
||||||
|
if jsonString == "" {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
|
||||||
|
str := strings.TrimSpace(jsonString)
|
||||||
|
if str == "" {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
|
||||||
|
// CONSERVATIVE STRATEGY: First try to parse directly
|
||||||
|
var testParse interface{}
|
||||||
|
if err := json.Unmarshal([]byte(str), &testParse); err == nil {
|
||||||
|
log.Debugf("kiro: repairJSON - JSON is already valid, returning unchanged")
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("kiro: repairJSON - JSON parse failed, attempting repair")
|
||||||
|
originalStr := str
|
||||||
|
|
||||||
|
// First, escape unescaped newlines/tabs within JSON string values
|
||||||
|
str = escapeNewlinesInStrings(str)
|
||||||
|
// Remove trailing commas before closing braces/brackets
|
||||||
|
str = trailingCommaPattern.ReplaceAllString(str, "$1")
|
||||||
|
|
||||||
|
// Calculate bracket balance
|
||||||
|
braceCount := 0
|
||||||
|
bracketCount := 0
|
||||||
|
inString := false
|
||||||
|
escape := false
|
||||||
|
lastValidIndex := -1
|
||||||
|
|
||||||
|
for i := 0; i < len(str); i++ {
|
||||||
|
char := str[i]
|
||||||
|
|
||||||
|
if escape {
|
||||||
|
escape = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if char == '\\' {
|
||||||
|
escape = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if char == '"' {
|
||||||
|
inString = !inString
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if inString {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch char {
|
||||||
|
case '{':
|
||||||
|
braceCount++
|
||||||
|
case '}':
|
||||||
|
braceCount--
|
||||||
|
case '[':
|
||||||
|
bracketCount++
|
||||||
|
case ']':
|
||||||
|
bracketCount--
|
||||||
|
}
|
||||||
|
|
||||||
|
if braceCount >= 0 && bracketCount >= 0 {
|
||||||
|
lastValidIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If brackets are unbalanced, try to repair
|
||||||
|
if braceCount > 0 || bracketCount > 0 {
|
||||||
|
if lastValidIndex > 0 && lastValidIndex < len(str)-1 {
|
||||||
|
truncated := str[:lastValidIndex+1]
|
||||||
|
// Recount brackets after truncation
|
||||||
|
braceCount = 0
|
||||||
|
bracketCount = 0
|
||||||
|
inString = false
|
||||||
|
escape = false
|
||||||
|
for i := 0; i < len(truncated); i++ {
|
||||||
|
char := truncated[i]
|
||||||
|
if escape {
|
||||||
|
escape = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if char == '\\' {
|
||||||
|
escape = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if char == '"' {
|
||||||
|
inString = !inString
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if inString {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch char {
|
||||||
|
case '{':
|
||||||
|
braceCount++
|
||||||
|
case '}':
|
||||||
|
braceCount--
|
||||||
|
case '[':
|
||||||
|
bracketCount++
|
||||||
|
case ']':
|
||||||
|
bracketCount--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
str = truncated
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add missing closing brackets
|
||||||
|
for braceCount > 0 {
|
||||||
|
str += "}"
|
||||||
|
braceCount--
|
||||||
|
}
|
||||||
|
for bracketCount > 0 {
|
||||||
|
str += "]"
|
||||||
|
bracketCount--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate repaired JSON
|
||||||
|
if err := json.Unmarshal([]byte(str), &testParse); err != nil {
|
||||||
|
log.Warnf("kiro: repairJSON - repair failed to produce valid JSON, returning original")
|
||||||
|
return originalStr
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("kiro: repairJSON - successfully repaired JSON")
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
// escapeNewlinesInStrings escapes literal newlines, tabs, and other control characters
|
||||||
|
// that appear inside JSON string values.
|
||||||
|
func escapeNewlinesInStrings(raw string) string {
|
||||||
|
var result strings.Builder
|
||||||
|
result.Grow(len(raw) + 100)
|
||||||
|
|
||||||
|
inString := false
|
||||||
|
escaped := false
|
||||||
|
|
||||||
|
for i := 0; i < len(raw); i++ {
|
||||||
|
c := raw[i]
|
||||||
|
|
||||||
|
if escaped {
|
||||||
|
result.WriteByte(c)
|
||||||
|
escaped = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c == '\\' && inString {
|
||||||
|
result.WriteByte(c)
|
||||||
|
escaped = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c == '"' {
|
||||||
|
inString = !inString
|
||||||
|
result.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if inString {
|
||||||
|
switch c {
|
||||||
|
case '\n':
|
||||||
|
result.WriteString("\\n")
|
||||||
|
case '\r':
|
||||||
|
result.WriteString("\\r")
|
||||||
|
case '\t':
|
||||||
|
result.WriteString("\\t")
|
||||||
|
default:
|
||||||
|
result.WriteByte(c)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result.WriteByte(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessToolUseEvent handles a toolUseEvent from the Kiro stream.
|
||||||
|
// It accumulates input fragments and emits tool_use blocks when complete.
|
||||||
|
// Returns events to emit and updated state.
|
||||||
|
func ProcessToolUseEvent(event map[string]interface{}, currentToolUse *ToolUseState, processedIDs map[string]bool) ([]KiroToolUse, *ToolUseState) {
|
||||||
|
var toolUses []KiroToolUse
|
||||||
|
|
||||||
|
// Extract from nested toolUseEvent or direct format
|
||||||
|
tu := event
|
||||||
|
if nested, ok := event["toolUseEvent"].(map[string]interface{}); ok {
|
||||||
|
tu = nested
|
||||||
|
}
|
||||||
|
|
||||||
|
toolUseID := kirocommon.GetString(tu, "toolUseId")
|
||||||
|
toolName := kirocommon.GetString(tu, "name")
|
||||||
|
isStop := false
|
||||||
|
if stop, ok := tu["stop"].(bool); ok {
|
||||||
|
isStop = stop
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get input - can be string (fragment) or object (complete)
|
||||||
|
var inputFragment string
|
||||||
|
var inputMap map[string]interface{}
|
||||||
|
|
||||||
|
if inputRaw, ok := tu["input"]; ok {
|
||||||
|
switch v := inputRaw.(type) {
|
||||||
|
case string:
|
||||||
|
inputFragment = v
|
||||||
|
case map[string]interface{}:
|
||||||
|
inputMap = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New tool use starting
|
||||||
|
if toolUseID != "" && toolName != "" {
|
||||||
|
if currentToolUse != nil && currentToolUse.ToolUseID != toolUseID {
|
||||||
|
log.Warnf("kiro: interleaved tool use detected - new ID %s arrived while %s in progress, completing previous",
|
||||||
|
toolUseID, currentToolUse.ToolUseID)
|
||||||
|
if !processedIDs[currentToolUse.ToolUseID] {
|
||||||
|
incomplete := KiroToolUse{
|
||||||
|
ToolUseID: currentToolUse.ToolUseID,
|
||||||
|
Name: currentToolUse.Name,
|
||||||
|
}
|
||||||
|
if currentToolUse.InputBuffer.Len() > 0 {
|
||||||
|
raw := currentToolUse.InputBuffer.String()
|
||||||
|
repaired := RepairJSON(raw)
|
||||||
|
|
||||||
|
var input map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(repaired), &input); err != nil {
|
||||||
|
log.Warnf("kiro: failed to parse interleaved tool input: %v, raw: %s", err, raw)
|
||||||
|
input = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
incomplete.Input = input
|
||||||
|
}
|
||||||
|
toolUses = append(toolUses, incomplete)
|
||||||
|
processedIDs[currentToolUse.ToolUseID] = true
|
||||||
|
}
|
||||||
|
currentToolUse = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentToolUse == nil {
|
||||||
|
if processedIDs != nil && processedIDs[toolUseID] {
|
||||||
|
log.Debugf("kiro: skipping duplicate toolUseEvent: %s", toolUseID)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
currentToolUse = &ToolUseState{
|
||||||
|
ToolUseID: toolUseID,
|
||||||
|
Name: toolName,
|
||||||
|
}
|
||||||
|
log.Infof("kiro: starting new tool use: %s (ID: %s)", toolName, toolUseID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accumulate input fragments
|
||||||
|
if currentToolUse != nil && inputFragment != "" {
|
||||||
|
currentToolUse.InputBuffer.WriteString(inputFragment)
|
||||||
|
log.Debugf("kiro: accumulated input fragment, total length: %d", currentToolUse.InputBuffer.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
// If complete input object provided directly
|
||||||
|
if currentToolUse != nil && inputMap != nil {
|
||||||
|
inputBytes, _ := json.Marshal(inputMap)
|
||||||
|
currentToolUse.InputBuffer.Reset()
|
||||||
|
currentToolUse.InputBuffer.Write(inputBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool use complete
|
||||||
|
if isStop && currentToolUse != nil {
|
||||||
|
fullInput := currentToolUse.InputBuffer.String()
|
||||||
|
|
||||||
|
// Repair and parse the accumulated JSON
|
||||||
|
repairedJSON := RepairJSON(fullInput)
|
||||||
|
var finalInput map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(repairedJSON), &finalInput); err != nil {
|
||||||
|
log.Warnf("kiro: failed to parse accumulated tool input: %v, raw: %s", err, fullInput)
|
||||||
|
finalInput = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
toolUse := KiroToolUse{
|
||||||
|
ToolUseID: currentToolUse.ToolUseID,
|
||||||
|
Name: currentToolUse.Name,
|
||||||
|
Input: finalInput,
|
||||||
|
}
|
||||||
|
toolUses = append(toolUses, toolUse)
|
||||||
|
|
||||||
|
if processedIDs != nil {
|
||||||
|
processedIDs[currentToolUse.ToolUseID] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("kiro: completed tool use: %s (ID: %s)", currentToolUse.Name, currentToolUse.ToolUseID)
|
||||||
|
return toolUses, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return toolUses, currentToolUse
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeduplicateToolUses removes duplicate tool uses based on toolUseId and content.
|
||||||
|
func DeduplicateToolUses(toolUses []KiroToolUse) []KiroToolUse {
|
||||||
|
seenIDs := make(map[string]bool)
|
||||||
|
seenContent := make(map[string]bool)
|
||||||
|
var unique []KiroToolUse
|
||||||
|
|
||||||
|
for _, tu := range toolUses {
|
||||||
|
if seenIDs[tu.ToolUseID] {
|
||||||
|
log.Debugf("kiro: removing ID-duplicate tool use: %s (name: %s)", tu.ToolUseID, tu.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
inputJSON, _ := json.Marshal(tu.Input)
|
||||||
|
contentKey := tu.Name + ":" + string(inputJSON)
|
||||||
|
|
||||||
|
if seenContent[contentKey] {
|
||||||
|
log.Debugf("kiro: removing content-duplicate tool use: %s (id: %s)", tu.Name, tu.ToolUseID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
seenIDs[tu.ToolUseID] = true
|
||||||
|
seenContent[contentKey] = true
|
||||||
|
unique = append(unique, tu)
|
||||||
|
}
|
||||||
|
|
||||||
|
return unique
|
||||||
|
}
|
||||||
|
|
||||||
75
internal/translator/kiro/common/constants.go
Normal file
75
internal/translator/kiro/common/constants.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
// Package common provides shared constants and utilities for Kiro translator.
|
||||||
|
package common
|
||||||
|
|
||||||
|
const (
|
||||||
|
// KiroMaxToolDescLen is the maximum description length for Kiro API tools.
|
||||||
|
// Kiro API limit is 10240 bytes, leave room for "..."
|
||||||
|
KiroMaxToolDescLen = 10237
|
||||||
|
|
||||||
|
// ThinkingStartTag is the start tag for thinking blocks in responses.
|
||||||
|
ThinkingStartTag = "<thinking>"
|
||||||
|
|
||||||
|
// ThinkingEndTag is the end tag for thinking blocks in responses.
|
||||||
|
ThinkingEndTag = "</thinking>"
|
||||||
|
|
||||||
|
// CodeFenceMarker is the markdown code fence marker.
|
||||||
|
CodeFenceMarker = "```"
|
||||||
|
|
||||||
|
// AltCodeFenceMarker is the alternative markdown code fence marker.
|
||||||
|
AltCodeFenceMarker = "~~~"
|
||||||
|
|
||||||
|
// InlineCodeMarker is the markdown inline code marker (backtick).
|
||||||
|
InlineCodeMarker = "`"
|
||||||
|
|
||||||
|
// KiroAgenticSystemPrompt is injected only for -agentic models to prevent timeouts on large writes.
|
||||||
|
// AWS Kiro API has a 2-3 minute timeout for large file write operations.
|
||||||
|
KiroAgenticSystemPrompt = `
|
||||||
|
# CRITICAL: CHUNKED WRITE PROTOCOL (MANDATORY)
|
||||||
|
|
||||||
|
You MUST follow these rules for ALL file operations. Violation causes server timeouts and task failure.
|
||||||
|
|
||||||
|
## ABSOLUTE LIMITS
|
||||||
|
- **MAXIMUM 350 LINES** per single write/edit operation - NO EXCEPTIONS
|
||||||
|
- **RECOMMENDED 300 LINES** or less for optimal performance
|
||||||
|
- **NEVER** write entire files in one operation if >300 lines
|
||||||
|
|
||||||
|
## MANDATORY CHUNKED WRITE STRATEGY
|
||||||
|
|
||||||
|
### For NEW FILES (>300 lines total):
|
||||||
|
1. FIRST: Write initial chunk (first 250-300 lines) using write_to_file/fsWrite
|
||||||
|
2. THEN: Append remaining content in 250-300 line chunks using file append operations
|
||||||
|
3. REPEAT: Continue appending until complete
|
||||||
|
|
||||||
|
### For EDITING EXISTING FILES:
|
||||||
|
1. Use surgical edits (apply_diff/targeted edits) - change ONLY what's needed
|
||||||
|
2. NEVER rewrite entire files - use incremental modifications
|
||||||
|
3. Split large refactors into multiple small, focused edits
|
||||||
|
|
||||||
|
### For LARGE CODE GENERATION:
|
||||||
|
1. Generate in logical sections (imports, types, functions separately)
|
||||||
|
2. Write each section as a separate operation
|
||||||
|
3. Use append operations for subsequent sections
|
||||||
|
|
||||||
|
## EXAMPLES OF CORRECT BEHAVIOR
|
||||||
|
|
||||||
|
✅ CORRECT: Writing a 600-line file
|
||||||
|
- Operation 1: Write lines 1-300 (initial file creation)
|
||||||
|
- Operation 2: Append lines 301-600
|
||||||
|
|
||||||
|
✅ CORRECT: Editing multiple functions
|
||||||
|
- Operation 1: Edit function A
|
||||||
|
- Operation 2: Edit function B
|
||||||
|
- Operation 3: Edit function C
|
||||||
|
|
||||||
|
❌ WRONG: Writing 500 lines in single operation → TIMEOUT
|
||||||
|
❌ WRONG: Rewriting entire file to change 5 lines → TIMEOUT
|
||||||
|
❌ WRONG: Generating massive code blocks without chunking → TIMEOUT
|
||||||
|
|
||||||
|
## WHY THIS MATTERS
|
||||||
|
- Server has 2-3 minute timeout for operations
|
||||||
|
- Large writes exceed timeout and FAIL completely
|
||||||
|
- Chunked writes are FASTER and more RELIABLE
|
||||||
|
- Failed writes waste time and require retry
|
||||||
|
|
||||||
|
REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.`
|
||||||
|
)
|
||||||
125
internal/translator/kiro/common/message_merge.go
Normal file
125
internal/translator/kiro/common/message_merge.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
// Package common provides shared utilities for Kiro translators.
|
||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MergeAdjacentMessages merges adjacent messages with the same role.
|
||||||
|
// This reduces API call complexity and improves compatibility.
|
||||||
|
// Based on AIClient-2-API implementation.
|
||||||
|
func MergeAdjacentMessages(messages []gjson.Result) []gjson.Result {
|
||||||
|
if len(messages) <= 1 {
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
var merged []gjson.Result
|
||||||
|
for _, msg := range messages {
|
||||||
|
if len(merged) == 0 {
|
||||||
|
merged = append(merged, msg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lastMsg := merged[len(merged)-1]
|
||||||
|
currentRole := msg.Get("role").String()
|
||||||
|
lastRole := lastMsg.Get("role").String()
|
||||||
|
|
||||||
|
if currentRole == lastRole {
|
||||||
|
// Merge content from current message into last message
|
||||||
|
mergedContent := mergeMessageContent(lastMsg, msg)
|
||||||
|
// Create a new merged message JSON
|
||||||
|
mergedMsg := createMergedMessage(lastRole, mergedContent)
|
||||||
|
merged[len(merged)-1] = gjson.Parse(mergedMsg)
|
||||||
|
} else {
|
||||||
|
merged = append(merged, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return merged
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeMessageContent merges the content of two messages with the same role.
|
||||||
|
// Handles both string content and array content (with text, tool_use, tool_result blocks).
|
||||||
|
func mergeMessageContent(msg1, msg2 gjson.Result) string {
|
||||||
|
content1 := msg1.Get("content")
|
||||||
|
content2 := msg2.Get("content")
|
||||||
|
|
||||||
|
// Extract content blocks from both messages
|
||||||
|
var blocks1, blocks2 []map[string]interface{}
|
||||||
|
|
||||||
|
if content1.IsArray() {
|
||||||
|
for _, block := range content1.Array() {
|
||||||
|
blocks1 = append(blocks1, blockToMap(block))
|
||||||
|
}
|
||||||
|
} else if content1.Type == gjson.String {
|
||||||
|
blocks1 = append(blocks1, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": content1.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if content2.IsArray() {
|
||||||
|
for _, block := range content2.Array() {
|
||||||
|
blocks2 = append(blocks2, blockToMap(block))
|
||||||
|
}
|
||||||
|
} else if content2.Type == gjson.String {
|
||||||
|
blocks2 = append(blocks2, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": content2.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge text blocks if both end/start with text
|
||||||
|
if len(blocks1) > 0 && len(blocks2) > 0 {
|
||||||
|
if blocks1[len(blocks1)-1]["type"] == "text" && blocks2[0]["type"] == "text" {
|
||||||
|
// Merge the last text block of msg1 with the first text block of msg2
|
||||||
|
text1 := blocks1[len(blocks1)-1]["text"].(string)
|
||||||
|
text2 := blocks2[0]["text"].(string)
|
||||||
|
blocks1[len(blocks1)-1]["text"] = text1 + "\n" + text2
|
||||||
|
blocks2 = blocks2[1:] // Remove the merged block from blocks2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine all blocks
|
||||||
|
allBlocks := append(blocks1, blocks2...)
|
||||||
|
|
||||||
|
// Convert to JSON
|
||||||
|
result, _ := json.Marshal(allBlocks)
|
||||||
|
return string(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// blockToMap converts a gjson.Result block to a map[string]interface{}
|
||||||
|
func blockToMap(block gjson.Result) map[string]interface{} {
|
||||||
|
result := make(map[string]interface{})
|
||||||
|
block.ForEach(func(key, value gjson.Result) bool {
|
||||||
|
if value.IsObject() {
|
||||||
|
result[key.String()] = blockToMap(value)
|
||||||
|
} else if value.IsArray() {
|
||||||
|
var arr []interface{}
|
||||||
|
for _, item := range value.Array() {
|
||||||
|
if item.IsObject() {
|
||||||
|
arr = append(arr, blockToMap(item))
|
||||||
|
} else {
|
||||||
|
arr = append(arr, item.Value())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result[key.String()] = arr
|
||||||
|
} else {
|
||||||
|
result[key.String()] = value.Value()
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// createMergedMessage creates a JSON string for a merged message
|
||||||
|
func createMergedMessage(role string, content string) string {
|
||||||
|
msg := map[string]interface{}{
|
||||||
|
"role": role,
|
||||||
|
"content": json.RawMessage(content),
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(msg)
|
||||||
|
return string(result)
|
||||||
|
}
|
||||||
16
internal/translator/kiro/common/utils.go
Normal file
16
internal/translator/kiro/common/utils.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// Package common provides shared constants and utilities for Kiro translator.
|
||||||
|
package common
|
||||||
|
|
||||||
|
// GetString safely extracts a string from a map.
|
||||||
|
// Returns empty string if the key doesn't exist or the value is not a string.
|
||||||
|
func GetString(m map[string]interface{}, key string) string {
|
||||||
|
if v, ok := m[key].(string); ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStringValue is an alias for GetString for backward compatibility.
|
||||||
|
func GetStringValue(m map[string]interface{}, key string) string {
|
||||||
|
return GetString(m, key)
|
||||||
|
}
|
||||||
20
internal/translator/kiro/openai/init.go
Normal file
20
internal/translator/kiro/openai/init.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
// Package openai provides translation between OpenAI Chat Completions and Kiro formats.
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
translator.Register(
|
||||||
|
OpenAI, // source format
|
||||||
|
Kiro, // target format
|
||||||
|
ConvertOpenAIRequestToKiro,
|
||||||
|
interfaces.TranslateResponse{
|
||||||
|
Stream: ConvertKiroStreamToOpenAI,
|
||||||
|
NonStream: ConvertKiroNonStreamToOpenAI,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
369
internal/translator/kiro/openai/kiro_openai.go
Normal file
369
internal/translator/kiro/openai/kiro_openai.go
Normal file
@@ -0,0 +1,369 @@
|
|||||||
|
// Package openai provides translation between OpenAI Chat Completions and Kiro formats.
|
||||||
|
// This package enables direct OpenAI → Kiro translation, bypassing the Claude intermediate layer.
|
||||||
|
//
|
||||||
|
// The Kiro executor generates Claude-compatible SSE format internally, so the streaming response
|
||||||
|
// translation converts from Claude SSE format to OpenAI SSE format.
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertKiroStreamToOpenAI converts Kiro streaming response to OpenAI format.
|
||||||
|
// The Kiro executor emits Claude-compatible SSE events, so this function translates
|
||||||
|
// from Claude SSE format to OpenAI SSE format.
|
||||||
|
//
|
||||||
|
// Claude SSE format:
|
||||||
|
// - event: message_start\ndata: {...}
|
||||||
|
// - event: content_block_start\ndata: {...}
|
||||||
|
// - event: content_block_delta\ndata: {...}
|
||||||
|
// - event: content_block_stop\ndata: {...}
|
||||||
|
// - event: message_delta\ndata: {...}
|
||||||
|
// - event: message_stop\ndata: {...}
|
||||||
|
//
|
||||||
|
// OpenAI SSE format:
|
||||||
|
// - data: {"id":"...","object":"chat.completion.chunk",...}
|
||||||
|
// - data: [DONE]
|
||||||
|
func ConvertKiroStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) []string {
|
||||||
|
// Initialize state if needed
|
||||||
|
if *param == nil {
|
||||||
|
*param = NewOpenAIStreamState(model)
|
||||||
|
}
|
||||||
|
state := (*param).(*OpenAIStreamState)
|
||||||
|
|
||||||
|
// Parse the Claude SSE event
|
||||||
|
responseStr := string(rawResponse)
|
||||||
|
|
||||||
|
// Handle raw event format (event: xxx\ndata: {...})
|
||||||
|
var eventType string
|
||||||
|
var eventData string
|
||||||
|
|
||||||
|
if strings.HasPrefix(responseStr, "event:") {
|
||||||
|
// Parse event type and data
|
||||||
|
lines := strings.SplitN(responseStr, "\n", 2)
|
||||||
|
if len(lines) >= 1 {
|
||||||
|
eventType = strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
|
||||||
|
}
|
||||||
|
if len(lines) >= 2 && strings.HasPrefix(lines[1], "data:") {
|
||||||
|
eventData = strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(responseStr, "data:") {
|
||||||
|
// Just data line
|
||||||
|
eventData = strings.TrimSpace(strings.TrimPrefix(responseStr, "data:"))
|
||||||
|
} else {
|
||||||
|
// Try to parse as raw JSON
|
||||||
|
eventData = strings.TrimSpace(responseStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if eventData == "" {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the event data as JSON
|
||||||
|
eventJSON := gjson.Parse(eventData)
|
||||||
|
if !eventJSON.Exists() {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine event type from JSON if not already set
|
||||||
|
if eventType == "" {
|
||||||
|
eventType = eventJSON.Get("type").String()
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []string
|
||||||
|
|
||||||
|
switch eventType {
|
||||||
|
case "message_start":
|
||||||
|
// Send first chunk with role
|
||||||
|
firstChunk := BuildOpenAISSEFirstChunk(state)
|
||||||
|
results = append(results, firstChunk)
|
||||||
|
|
||||||
|
case "content_block_start":
|
||||||
|
// Check block type
|
||||||
|
blockType := eventJSON.Get("content_block.type").String()
|
||||||
|
switch blockType {
|
||||||
|
case "text":
|
||||||
|
// Text block starting - nothing to emit yet
|
||||||
|
case "thinking":
|
||||||
|
// Thinking block starting - nothing to emit yet for OpenAI
|
||||||
|
case "tool_use":
|
||||||
|
// Tool use block starting
|
||||||
|
toolUseID := eventJSON.Get("content_block.id").String()
|
||||||
|
toolName := eventJSON.Get("content_block.name").String()
|
||||||
|
chunk := BuildOpenAISSEToolCallStart(state, toolUseID, toolName)
|
||||||
|
results = append(results, chunk)
|
||||||
|
state.ToolCallIndex++
|
||||||
|
}
|
||||||
|
|
||||||
|
case "content_block_delta":
|
||||||
|
deltaType := eventJSON.Get("delta.type").String()
|
||||||
|
switch deltaType {
|
||||||
|
case "text_delta":
|
||||||
|
textDelta := eventJSON.Get("delta.text").String()
|
||||||
|
if textDelta != "" {
|
||||||
|
chunk := BuildOpenAISSETextDelta(state, textDelta)
|
||||||
|
results = append(results, chunk)
|
||||||
|
}
|
||||||
|
case "thinking_delta":
|
||||||
|
// Convert thinking to reasoning_content for o1-style compatibility
|
||||||
|
thinkingDelta := eventJSON.Get("delta.thinking").String()
|
||||||
|
if thinkingDelta != "" {
|
||||||
|
chunk := BuildOpenAISSEReasoningDelta(state, thinkingDelta)
|
||||||
|
results = append(results, chunk)
|
||||||
|
}
|
||||||
|
case "input_json_delta":
|
||||||
|
// Tool call arguments delta
|
||||||
|
partialJSON := eventJSON.Get("delta.partial_json").String()
|
||||||
|
if partialJSON != "" {
|
||||||
|
// Get the tool index from content block index
|
||||||
|
blockIndex := int(eventJSON.Get("index").Int())
|
||||||
|
chunk := BuildOpenAISSEToolCallArgumentsDelta(state, partialJSON, blockIndex-1) // Adjust for 0-based tool index
|
||||||
|
results = append(results, chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "content_block_stop":
|
||||||
|
// Content block ended - nothing to emit for OpenAI
|
||||||
|
|
||||||
|
case "message_delta":
|
||||||
|
// Message delta with stop_reason
|
||||||
|
stopReason := eventJSON.Get("delta.stop_reason").String()
|
||||||
|
finishReason := mapKiroStopReasonToOpenAI(stopReason)
|
||||||
|
if finishReason != "" {
|
||||||
|
chunk := BuildOpenAISSEFinish(state, finishReason)
|
||||||
|
results = append(results, chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract usage if present
|
||||||
|
if eventJSON.Get("usage").Exists() {
|
||||||
|
inputTokens := eventJSON.Get("usage.input_tokens").Int()
|
||||||
|
outputTokens := eventJSON.Get("usage.output_tokens").Int()
|
||||||
|
usageInfo := usage.Detail{
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
TotalTokens: inputTokens + outputTokens,
|
||||||
|
}
|
||||||
|
chunk := BuildOpenAISSEUsage(state, usageInfo)
|
||||||
|
results = append(results, chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
case "message_stop":
|
||||||
|
// Final event - do NOT emit [DONE] here
|
||||||
|
// The handler layer (openai_handlers.go) will send [DONE] when the stream closes
|
||||||
|
// Emitting [DONE] here would cause duplicate [DONE] markers
|
||||||
|
|
||||||
|
case "ping":
|
||||||
|
// Ping event with usage - optionally emit usage chunk
|
||||||
|
if eventJSON.Get("usage").Exists() {
|
||||||
|
inputTokens := eventJSON.Get("usage.input_tokens").Int()
|
||||||
|
outputTokens := eventJSON.Get("usage.output_tokens").Int()
|
||||||
|
usageInfo := usage.Detail{
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
TotalTokens: inputTokens + outputTokens,
|
||||||
|
}
|
||||||
|
chunk := BuildOpenAISSEUsage(state, usageInfo)
|
||||||
|
results = append(results, chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertKiroNonStreamToOpenAI converts Kiro non-streaming response to OpenAI format.
|
||||||
|
// The Kiro executor returns Claude-compatible JSON responses, so this function translates
|
||||||
|
// from Claude format to OpenAI format.
|
||||||
|
func ConvertKiroNonStreamToOpenAI(ctx context.Context, model string, originalRequest, request, rawResponse []byte, param *any) string {
|
||||||
|
// Parse the Claude-format response
|
||||||
|
response := gjson.ParseBytes(rawResponse)
|
||||||
|
|
||||||
|
// Extract content
|
||||||
|
var content string
|
||||||
|
var toolUses []KiroToolUse
|
||||||
|
var stopReason string
|
||||||
|
|
||||||
|
// Get stop_reason
|
||||||
|
stopReason = response.Get("stop_reason").String()
|
||||||
|
|
||||||
|
// Process content blocks
|
||||||
|
contentBlocks := response.Get("content")
|
||||||
|
if contentBlocks.IsArray() {
|
||||||
|
for _, block := range contentBlocks.Array() {
|
||||||
|
blockType := block.Get("type").String()
|
||||||
|
switch blockType {
|
||||||
|
case "text":
|
||||||
|
content += block.Get("text").String()
|
||||||
|
case "thinking":
|
||||||
|
// Skip thinking blocks for OpenAI format (or convert to reasoning_content if needed)
|
||||||
|
case "tool_use":
|
||||||
|
toolUseID := block.Get("id").String()
|
||||||
|
toolName := block.Get("name").String()
|
||||||
|
toolInput := block.Get("input")
|
||||||
|
|
||||||
|
var inputMap map[string]interface{}
|
||||||
|
if toolInput.IsObject() {
|
||||||
|
inputMap = make(map[string]interface{})
|
||||||
|
toolInput.ForEach(func(key, value gjson.Result) bool {
|
||||||
|
inputMap[key.String()] = value.Value()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
toolUses = append(toolUses, KiroToolUse{
|
||||||
|
ToolUseID: toolUseID,
|
||||||
|
Name: toolName,
|
||||||
|
Input: inputMap,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract usage
|
||||||
|
usageInfo := usage.Detail{
|
||||||
|
InputTokens: response.Get("usage.input_tokens").Int(),
|
||||||
|
OutputTokens: response.Get("usage.output_tokens").Int(),
|
||||||
|
}
|
||||||
|
usageInfo.TotalTokens = usageInfo.InputTokens + usageInfo.OutputTokens
|
||||||
|
|
||||||
|
// Build OpenAI response
|
||||||
|
openaiResponse := BuildOpenAIResponse(content, toolUses, model, usageInfo, stopReason)
|
||||||
|
return string(openaiResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseClaudeEvent parses a Claude SSE event and returns the event type and data
|
||||||
|
func ParseClaudeEvent(rawEvent []byte) (eventType string, eventData []byte) {
|
||||||
|
lines := bytes.Split(rawEvent, []byte("\n"))
|
||||||
|
for _, line := range lines {
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if bytes.HasPrefix(line, []byte("event:")) {
|
||||||
|
eventType = string(bytes.TrimSpace(bytes.TrimPrefix(line, []byte("event:"))))
|
||||||
|
} else if bytes.HasPrefix(line, []byte("data:")) {
|
||||||
|
eventData = bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return eventType, eventData
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractThinkingFromContent parses content to extract thinking blocks.
|
||||||
|
// Returns cleaned content (without thinking tags) and whether thinking was found.
|
||||||
|
func ExtractThinkingFromContent(content string) (string, string, bool) {
|
||||||
|
if !strings.Contains(content, kirocommon.ThinkingStartTag) {
|
||||||
|
return content, "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
var cleanedContent strings.Builder
|
||||||
|
var thinkingContent strings.Builder
|
||||||
|
hasThinking := false
|
||||||
|
remaining := content
|
||||||
|
|
||||||
|
for len(remaining) > 0 {
|
||||||
|
startIdx := strings.Index(remaining, kirocommon.ThinkingStartTag)
|
||||||
|
if startIdx == -1 {
|
||||||
|
cleanedContent.WriteString(remaining)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add content before thinking tag
|
||||||
|
cleanedContent.WriteString(remaining[:startIdx])
|
||||||
|
|
||||||
|
// Move past opening tag
|
||||||
|
remaining = remaining[startIdx+len(kirocommon.ThinkingStartTag):]
|
||||||
|
|
||||||
|
// Find closing tag
|
||||||
|
endIdx := strings.Index(remaining, kirocommon.ThinkingEndTag)
|
||||||
|
if endIdx == -1 {
|
||||||
|
// No closing tag - treat rest as thinking
|
||||||
|
thinkingContent.WriteString(remaining)
|
||||||
|
hasThinking = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract thinking content
|
||||||
|
thinkingContent.WriteString(remaining[:endIdx])
|
||||||
|
hasThinking = true
|
||||||
|
remaining = remaining[endIdx+len(kirocommon.ThinkingEndTag):]
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimSpace(cleanedContent.String()), strings.TrimSpace(thinkingContent.String()), hasThinking
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertOpenAIToolsToKiroFormat is a helper that converts OpenAI tools format to Kiro format
|
||||||
|
func ConvertOpenAIToolsToKiroFormat(tools []map[string]interface{}) []KiroToolWrapper {
|
||||||
|
var kiroTools []KiroToolWrapper
|
||||||
|
|
||||||
|
for _, tool := range tools {
|
||||||
|
toolType, _ := tool["type"].(string)
|
||||||
|
if toolType != "function" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fn, ok := tool["function"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
name := kirocommon.GetString(fn, "name")
|
||||||
|
description := kirocommon.GetString(fn, "description")
|
||||||
|
parameters := fn["parameters"]
|
||||||
|
|
||||||
|
if name == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if description == "" {
|
||||||
|
description = "Tool: " + name
|
||||||
|
}
|
||||||
|
|
||||||
|
kiroTools = append(kiroTools, KiroToolWrapper{
|
||||||
|
ToolSpecification: KiroToolSpecification{
|
||||||
|
Name: name,
|
||||||
|
Description: description,
|
||||||
|
InputSchema: KiroInputSchema{JSON: parameters},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return kiroTools
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAIStreamParams holds parameters for OpenAI streaming conversion
|
||||||
|
type OpenAIStreamParams struct {
|
||||||
|
State *OpenAIStreamState
|
||||||
|
ThinkingState *ThinkingTagState
|
||||||
|
ToolCallsEmitted map[string]bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOpenAIStreamParams creates new streaming parameters
|
||||||
|
func NewOpenAIStreamParams(model string) *OpenAIStreamParams {
|
||||||
|
return &OpenAIStreamParams{
|
||||||
|
State: NewOpenAIStreamState(model),
|
||||||
|
ThinkingState: NewThinkingTagState(),
|
||||||
|
ToolCallsEmitted: make(map[string]bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertClaudeToolUseToOpenAI converts a Claude tool_use block to OpenAI tool_calls format
|
||||||
|
func ConvertClaudeToolUseToOpenAI(toolUseID, toolName string, input map[string]interface{}) map[string]interface{} {
|
||||||
|
inputJSON, _ := json.Marshal(input)
|
||||||
|
return map[string]interface{}{
|
||||||
|
"id": toolUseID,
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]interface{}{
|
||||||
|
"name": toolName,
|
||||||
|
"arguments": string(inputJSON),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogStreamEvent logs a streaming event for debugging
|
||||||
|
func LogStreamEvent(eventType, data string) {
|
||||||
|
log.Debugf("kiro-openai: stream event type=%s, data_len=%d", eventType, len(data))
|
||||||
|
}
|
||||||
848
internal/translator/kiro/openai/kiro_openai_request.go
Normal file
848
internal/translator/kiro/openai/kiro_openai_request.go
Normal file
@@ -0,0 +1,848 @@
|
|||||||
|
// Package openai provides request translation from OpenAI Chat Completions to Kiro format.
|
||||||
|
// It handles parsing and transforming OpenAI API requests into the Kiro/Amazon Q API format,
|
||||||
|
// extracting model information, system instructions, message contents, and tool declarations.
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
kirocommon "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/kiro/common"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Kiro API request structs - reuse from kiroclaude package structure
|
||||||
|
|
||||||
|
// KiroPayload is the top-level request structure for Kiro API
|
||||||
|
type KiroPayload struct {
|
||||||
|
ConversationState KiroConversationState `json:"conversationState"`
|
||||||
|
ProfileArn string `json:"profileArn,omitempty"`
|
||||||
|
InferenceConfig *KiroInferenceConfig `json:"inferenceConfig,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroInferenceConfig contains inference parameters for the Kiro API.
|
||||||
|
type KiroInferenceConfig struct {
|
||||||
|
MaxTokens int `json:"maxTokens,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"topP,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroConversationState holds the conversation context
|
||||||
|
type KiroConversationState struct {
|
||||||
|
ChatTriggerType string `json:"chatTriggerType"` // Required: "MANUAL"
|
||||||
|
ConversationID string `json:"conversationId"`
|
||||||
|
CurrentMessage KiroCurrentMessage `json:"currentMessage"`
|
||||||
|
History []KiroHistoryMessage `json:"history,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroCurrentMessage wraps the current user message
|
||||||
|
type KiroCurrentMessage struct {
|
||||||
|
UserInputMessage KiroUserInputMessage `json:"userInputMessage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroHistoryMessage represents a message in the conversation history
|
||||||
|
type KiroHistoryMessage struct {
|
||||||
|
UserInputMessage *KiroUserInputMessage `json:"userInputMessage,omitempty"`
|
||||||
|
AssistantResponseMessage *KiroAssistantResponseMessage `json:"assistantResponseMessage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroImage represents an image in Kiro API format
|
||||||
|
type KiroImage struct {
|
||||||
|
Format string `json:"format"`
|
||||||
|
Source KiroImageSource `json:"source"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroImageSource contains the image data
|
||||||
|
type KiroImageSource struct {
|
||||||
|
Bytes string `json:"bytes"` // base64 encoded image data
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroUserInputMessage represents a user message
|
||||||
|
type KiroUserInputMessage struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
ModelID string `json:"modelId"`
|
||||||
|
Origin string `json:"origin"`
|
||||||
|
Images []KiroImage `json:"images,omitempty"`
|
||||||
|
UserInputMessageContext *KiroUserInputMessageContext `json:"userInputMessageContext,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroUserInputMessageContext contains tool-related context
|
||||||
|
type KiroUserInputMessageContext struct {
|
||||||
|
ToolResults []KiroToolResult `json:"toolResults,omitempty"`
|
||||||
|
Tools []KiroToolWrapper `json:"tools,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroToolResult represents a tool execution result
|
||||||
|
type KiroToolResult struct {
|
||||||
|
Content []KiroTextContent `json:"content"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
ToolUseID string `json:"toolUseId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroTextContent represents text content
|
||||||
|
type KiroTextContent struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroToolWrapper wraps a tool specification
|
||||||
|
type KiroToolWrapper struct {
|
||||||
|
ToolSpecification KiroToolSpecification `json:"toolSpecification"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroToolSpecification defines a tool's schema
|
||||||
|
type KiroToolSpecification struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
InputSchema KiroInputSchema `json:"inputSchema"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroInputSchema wraps the JSON schema for tool input
|
||||||
|
type KiroInputSchema struct {
|
||||||
|
JSON interface{} `json:"json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroAssistantResponseMessage represents an assistant message
|
||||||
|
type KiroAssistantResponseMessage struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
ToolUses []KiroToolUse `json:"toolUses,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroToolUse represents a tool invocation by the assistant
|
||||||
|
type KiroToolUse struct {
|
||||||
|
ToolUseID string `json:"toolUseId"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Input map[string]interface{} `json:"input"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertOpenAIRequestToKiro converts an OpenAI Chat Completions request to Kiro format.
|
||||||
|
// This is the main entry point for request translation.
|
||||||
|
// Note: The actual payload building happens in the executor, this just passes through
|
||||||
|
// the OpenAI format which will be converted by BuildKiroPayloadFromOpenAI.
|
||||||
|
func ConvertOpenAIRequestToKiro(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||||
|
// Pass through the OpenAI format - actual conversion happens in BuildKiroPayloadFromOpenAI
|
||||||
|
return inputRawJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildKiroPayloadFromOpenAI constructs the Kiro API request payload from OpenAI format.
|
||||||
|
// Supports tool calling - tools are passed via userInputMessageContext.
|
||||||
|
// origin parameter determines which quota to use: "CLI" for Amazon Q, "AI_EDITOR" for Kiro IDE.
|
||||||
|
// isAgentic parameter enables chunked write optimization prompt for -agentic model variants.
|
||||||
|
// isChatOnly parameter disables tool calling for -chat model variants (pure conversation mode).
|
||||||
|
// Returns the payload and a boolean indicating whether thinking mode was injected.
|
||||||
|
func BuildKiroPayloadFromOpenAI(openaiBody []byte, modelID, profileArn, origin string, isAgentic, isChatOnly bool) ([]byte, bool) {
|
||||||
|
// Extract max_tokens for potential use in inferenceConfig
|
||||||
|
// Handle -1 as "use maximum" (Kiro max output is ~32000 tokens)
|
||||||
|
const kiroMaxOutputTokens = 32000
|
||||||
|
var maxTokens int64
|
||||||
|
if mt := gjson.GetBytes(openaiBody, "max_tokens"); mt.Exists() {
|
||||||
|
maxTokens = mt.Int()
|
||||||
|
if maxTokens == -1 {
|
||||||
|
maxTokens = kiroMaxOutputTokens
|
||||||
|
log.Debugf("kiro-openai: max_tokens=-1 converted to %d", kiroMaxOutputTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract temperature if specified
|
||||||
|
var temperature float64
|
||||||
|
var hasTemperature bool
|
||||||
|
if temp := gjson.GetBytes(openaiBody, "temperature"); temp.Exists() {
|
||||||
|
temperature = temp.Float()
|
||||||
|
hasTemperature = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract top_p if specified
|
||||||
|
var topP float64
|
||||||
|
var hasTopP bool
|
||||||
|
if tp := gjson.GetBytes(openaiBody, "top_p"); tp.Exists() {
|
||||||
|
topP = tp.Float()
|
||||||
|
hasTopP = true
|
||||||
|
log.Debugf("kiro-openai: extracted top_p: %.2f", topP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize origin value for Kiro API compatibility
|
||||||
|
origin = normalizeOrigin(origin)
|
||||||
|
log.Debugf("kiro-openai: normalized origin value: %s", origin)
|
||||||
|
|
||||||
|
messages := gjson.GetBytes(openaiBody, "messages")
|
||||||
|
|
||||||
|
// For chat-only mode, don't include tools
|
||||||
|
var tools gjson.Result
|
||||||
|
if !isChatOnly {
|
||||||
|
tools = gjson.GetBytes(openaiBody, "tools")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract system prompt from messages
|
||||||
|
systemPrompt := extractSystemPromptFromOpenAI(messages)
|
||||||
|
|
||||||
|
// Inject timestamp context
|
||||||
|
timestamp := time.Now().Format("2006-01-02 15:04:05 MST")
|
||||||
|
timestampContext := fmt.Sprintf("[Context: Current time is %s]", timestamp)
|
||||||
|
if systemPrompt != "" {
|
||||||
|
systemPrompt = timestampContext + "\n\n" + systemPrompt
|
||||||
|
} else {
|
||||||
|
systemPrompt = timestampContext
|
||||||
|
}
|
||||||
|
log.Debugf("kiro-openai: injected timestamp context: %s", timestamp)
|
||||||
|
|
||||||
|
// Inject agentic optimization prompt for -agentic model variants
|
||||||
|
if isAgentic {
|
||||||
|
if systemPrompt != "" {
|
||||||
|
systemPrompt += "\n"
|
||||||
|
}
|
||||||
|
systemPrompt += kirocommon.KiroAgenticSystemPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle tool_choice parameter - Kiro doesn't support it natively, so we inject system prompt hints
|
||||||
|
// OpenAI tool_choice values: "none", "auto", "required", or {"type":"function","function":{"name":"..."}}
|
||||||
|
toolChoiceHint := extractToolChoiceHint(openaiBody)
|
||||||
|
if toolChoiceHint != "" {
|
||||||
|
if systemPrompt != "" {
|
||||||
|
systemPrompt += "\n"
|
||||||
|
}
|
||||||
|
systemPrompt += toolChoiceHint
|
||||||
|
log.Debugf("kiro-openai: injected tool_choice hint into system prompt")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle response_format parameter - Kiro doesn't support it natively, so we inject system prompt hints
|
||||||
|
// OpenAI response_format: {"type": "json_object"} or {"type": "json_schema", "json_schema": {...}}
|
||||||
|
responseFormatHint := extractResponseFormatHint(openaiBody)
|
||||||
|
if responseFormatHint != "" {
|
||||||
|
if systemPrompt != "" {
|
||||||
|
systemPrompt += "\n"
|
||||||
|
}
|
||||||
|
systemPrompt += responseFormatHint
|
||||||
|
log.Debugf("kiro-openai: injected response_format hint into system prompt")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for thinking mode and inject thinking hint
|
||||||
|
// Supports OpenAI reasoning_effort parameter and model name hints
|
||||||
|
thinkingEnabled, budgetTokens := checkThinkingModeFromOpenAI(openaiBody)
|
||||||
|
if thinkingEnabled {
|
||||||
|
// Adjust budgetTokens based on max_tokens if not explicitly set by reasoning_effort
|
||||||
|
// Use 50% of max_tokens for thinking, with min 8000 and max 24000
|
||||||
|
if maxTokens > 0 && budgetTokens == 16000 { // 16000 is the default, meaning not explicitly set
|
||||||
|
calculatedBudget := maxTokens / 2
|
||||||
|
if calculatedBudget < 8000 {
|
||||||
|
calculatedBudget = 8000
|
||||||
|
}
|
||||||
|
if calculatedBudget > 24000 {
|
||||||
|
calculatedBudget = 24000
|
||||||
|
}
|
||||||
|
budgetTokens = calculatedBudget
|
||||||
|
log.Debugf("kiro-openai: budgetTokens calculated from max_tokens: %d (max_tokens=%d)", budgetTokens, maxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
if systemPrompt != "" {
|
||||||
|
systemPrompt += "\n"
|
||||||
|
}
|
||||||
|
dynamicThinkingHint := fmt.Sprintf("<thinking_mode>interleaved</thinking_mode><max_thinking_length>%d</max_thinking_length>", budgetTokens)
|
||||||
|
systemPrompt += dynamicThinkingHint
|
||||||
|
log.Debugf("kiro-openai: injected dynamic thinking hint into system prompt, max_thinking_length: %d", budgetTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert OpenAI tools to Kiro format
|
||||||
|
kiroTools := convertOpenAIToolsToKiro(tools)
|
||||||
|
|
||||||
|
// Process messages and build history
|
||||||
|
history, currentUserMsg, currentToolResults := processOpenAIMessages(messages, modelID, origin)
|
||||||
|
|
||||||
|
// Build content with system prompt
|
||||||
|
if currentUserMsg != nil {
|
||||||
|
currentUserMsg.Content = buildFinalContent(currentUserMsg.Content, systemPrompt, currentToolResults)
|
||||||
|
|
||||||
|
// Deduplicate currentToolResults
|
||||||
|
currentToolResults = deduplicateToolResults(currentToolResults)
|
||||||
|
|
||||||
|
// Build userInputMessageContext with tools and tool results
|
||||||
|
if len(kiroTools) > 0 || len(currentToolResults) > 0 {
|
||||||
|
currentUserMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||||
|
Tools: kiroTools,
|
||||||
|
ToolResults: currentToolResults,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build payload
|
||||||
|
var currentMessage KiroCurrentMessage
|
||||||
|
if currentUserMsg != nil {
|
||||||
|
currentMessage = KiroCurrentMessage{UserInputMessage: *currentUserMsg}
|
||||||
|
} else {
|
||||||
|
fallbackContent := ""
|
||||||
|
if systemPrompt != "" {
|
||||||
|
fallbackContent = "--- SYSTEM PROMPT ---\n" + systemPrompt + "\n--- END SYSTEM PROMPT ---\n"
|
||||||
|
}
|
||||||
|
currentMessage = KiroCurrentMessage{UserInputMessage: KiroUserInputMessage{
|
||||||
|
Content: fallbackContent,
|
||||||
|
ModelID: modelID,
|
||||||
|
Origin: origin,
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build inferenceConfig if we have any inference parameters
|
||||||
|
var inferenceConfig *KiroInferenceConfig
|
||||||
|
if maxTokens > 0 || hasTemperature || hasTopP {
|
||||||
|
inferenceConfig = &KiroInferenceConfig{}
|
||||||
|
if maxTokens > 0 {
|
||||||
|
inferenceConfig.MaxTokens = int(maxTokens)
|
||||||
|
}
|
||||||
|
if hasTemperature {
|
||||||
|
inferenceConfig.Temperature = temperature
|
||||||
|
}
|
||||||
|
if hasTopP {
|
||||||
|
inferenceConfig.TopP = topP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := KiroPayload{
|
||||||
|
ConversationState: KiroConversationState{
|
||||||
|
ChatTriggerType: "MANUAL",
|
||||||
|
ConversationID: uuid.New().String(),
|
||||||
|
CurrentMessage: currentMessage,
|
||||||
|
History: history,
|
||||||
|
},
|
||||||
|
ProfileArn: profileArn,
|
||||||
|
InferenceConfig: inferenceConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("kiro-openai: failed to marshal payload: %v", err)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, thinkingEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeOrigin normalizes origin value for Kiro API compatibility
|
||||||
|
func normalizeOrigin(origin string) string {
|
||||||
|
switch origin {
|
||||||
|
case "KIRO_CLI":
|
||||||
|
return "CLI"
|
||||||
|
case "KIRO_AI_EDITOR":
|
||||||
|
return "AI_EDITOR"
|
||||||
|
case "AMAZON_Q":
|
||||||
|
return "CLI"
|
||||||
|
case "KIRO_IDE":
|
||||||
|
return "AI_EDITOR"
|
||||||
|
default:
|
||||||
|
return origin
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractSystemPromptFromOpenAI extracts system prompt from OpenAI messages
|
||||||
|
func extractSystemPromptFromOpenAI(messages gjson.Result) string {
|
||||||
|
if !messages.IsArray() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var systemParts []string
|
||||||
|
for _, msg := range messages.Array() {
|
||||||
|
if msg.Get("role").String() == "system" {
|
||||||
|
content := msg.Get("content")
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
systemParts = append(systemParts, content.String())
|
||||||
|
} else if content.IsArray() {
|
||||||
|
// Handle array content format
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
if part.Get("type").String() == "text" {
|
||||||
|
systemParts = append(systemParts, part.Get("text").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(systemParts, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// shortenToolNameIfNeeded shortens tool names that exceed 64 characters.
|
||||||
|
// MCP tools often have long names like "mcp__server-name__tool-name".
|
||||||
|
// This preserves the "mcp__" prefix and last segment when possible.
|
||||||
|
func shortenToolNameIfNeeded(name string) string {
|
||||||
|
const limit = 64
|
||||||
|
if len(name) <= limit {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
// For MCP tools, try to preserve prefix and last segment
|
||||||
|
if strings.HasPrefix(name, "mcp__") {
|
||||||
|
idx := strings.LastIndex(name, "__")
|
||||||
|
if idx > 0 {
|
||||||
|
cand := "mcp__" + name[idx+2:]
|
||||||
|
if len(cand) > limit {
|
||||||
|
return cand[:limit]
|
||||||
|
}
|
||||||
|
return cand
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return name[:limit]
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertOpenAIToolsToKiro converts OpenAI tools to Kiro format
|
||||||
|
func convertOpenAIToolsToKiro(tools gjson.Result) []KiroToolWrapper {
|
||||||
|
var kiroTools []KiroToolWrapper
|
||||||
|
if !tools.IsArray() {
|
||||||
|
return kiroTools
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tool := range tools.Array() {
|
||||||
|
// OpenAI tools have type "function" with function definition inside
|
||||||
|
if tool.Get("type").String() != "function" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fn := tool.Get("function")
|
||||||
|
if !fn.Exists() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
name := fn.Get("name").String()
|
||||||
|
description := fn.Get("description").String()
|
||||||
|
parameters := fn.Get("parameters").Value()
|
||||||
|
|
||||||
|
// Shorten tool name if it exceeds 64 characters (common with MCP tools)
|
||||||
|
originalName := name
|
||||||
|
name = shortenToolNameIfNeeded(name)
|
||||||
|
if name != originalName {
|
||||||
|
log.Debugf("kiro-openai: shortened tool name from '%s' to '%s'", originalName, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CRITICAL FIX: Kiro API requires non-empty description
|
||||||
|
if strings.TrimSpace(description) == "" {
|
||||||
|
description = fmt.Sprintf("Tool: %s", name)
|
||||||
|
log.Debugf("kiro-openai: tool '%s' has empty description, using default: %s", name, description)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truncate long descriptions
|
||||||
|
if len(description) > kirocommon.KiroMaxToolDescLen {
|
||||||
|
truncLen := kirocommon.KiroMaxToolDescLen - 30
|
||||||
|
for truncLen > 0 && !utf8.RuneStart(description[truncLen]) {
|
||||||
|
truncLen--
|
||||||
|
}
|
||||||
|
description = description[:truncLen] + "... (description truncated)"
|
||||||
|
}
|
||||||
|
|
||||||
|
kiroTools = append(kiroTools, KiroToolWrapper{
|
||||||
|
ToolSpecification: KiroToolSpecification{
|
||||||
|
Name: name,
|
||||||
|
Description: description,
|
||||||
|
InputSchema: KiroInputSchema{JSON: parameters},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return kiroTools
|
||||||
|
}
|
||||||
|
|
||||||
|
// processOpenAIMessages processes OpenAI messages and builds Kiro history
|
||||||
|
func processOpenAIMessages(messages gjson.Result, modelID, origin string) ([]KiroHistoryMessage, *KiroUserInputMessage, []KiroToolResult) {
|
||||||
|
var history []KiroHistoryMessage
|
||||||
|
var currentUserMsg *KiroUserInputMessage
|
||||||
|
var currentToolResults []KiroToolResult
|
||||||
|
|
||||||
|
if !messages.IsArray() {
|
||||||
|
return history, currentUserMsg, currentToolResults
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge adjacent messages with the same role
|
||||||
|
messagesArray := kirocommon.MergeAdjacentMessages(messages.Array())
|
||||||
|
|
||||||
|
// Build tool_call_id to name mapping from assistant messages
|
||||||
|
toolCallIDToName := make(map[string]string)
|
||||||
|
for _, msg := range messagesArray {
|
||||||
|
if msg.Get("role").String() == "assistant" {
|
||||||
|
toolCalls := msg.Get("tool_calls")
|
||||||
|
if toolCalls.IsArray() {
|
||||||
|
for _, tc := range toolCalls.Array() {
|
||||||
|
if tc.Get("type").String() == "function" {
|
||||||
|
id := tc.Get("id").String()
|
||||||
|
name := tc.Get("function.name").String()
|
||||||
|
if id != "" && name != "" {
|
||||||
|
toolCallIDToName[id] = name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, msg := range messagesArray {
|
||||||
|
role := msg.Get("role").String()
|
||||||
|
isLastMessage := i == len(messagesArray)-1
|
||||||
|
|
||||||
|
switch role {
|
||||||
|
case "system":
|
||||||
|
// System messages are handled separately via extractSystemPromptFromOpenAI
|
||||||
|
continue
|
||||||
|
|
||||||
|
case "user":
|
||||||
|
userMsg, toolResults := buildUserMessageFromOpenAI(msg, modelID, origin)
|
||||||
|
if isLastMessage {
|
||||||
|
currentUserMsg = &userMsg
|
||||||
|
currentToolResults = toolResults
|
||||||
|
} else {
|
||||||
|
// CRITICAL: Kiro API requires content to be non-empty for history messages
|
||||||
|
if strings.TrimSpace(userMsg.Content) == "" {
|
||||||
|
if len(toolResults) > 0 {
|
||||||
|
userMsg.Content = "Tool results provided."
|
||||||
|
} else {
|
||||||
|
userMsg.Content = "Continue"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// For history messages, embed tool results in context
|
||||||
|
if len(toolResults) > 0 {
|
||||||
|
userMsg.UserInputMessageContext = &KiroUserInputMessageContext{
|
||||||
|
ToolResults: toolResults,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
history = append(history, KiroHistoryMessage{
|
||||||
|
UserInputMessage: &userMsg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case "assistant":
|
||||||
|
assistantMsg := buildAssistantMessageFromOpenAI(msg)
|
||||||
|
if isLastMessage {
|
||||||
|
history = append(history, KiroHistoryMessage{
|
||||||
|
AssistantResponseMessage: &assistantMsg,
|
||||||
|
})
|
||||||
|
// Create a "Continue" user message as currentMessage
|
||||||
|
currentUserMsg = &KiroUserInputMessage{
|
||||||
|
Content: "Continue",
|
||||||
|
ModelID: modelID,
|
||||||
|
Origin: origin,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
history = append(history, KiroHistoryMessage{
|
||||||
|
AssistantResponseMessage: &assistantMsg,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case "tool":
|
||||||
|
// Tool messages in OpenAI format provide results for tool_calls
|
||||||
|
// These are typically followed by user or assistant messages
|
||||||
|
// Process them and merge into the next user message's tool results
|
||||||
|
toolCallID := msg.Get("tool_call_id").String()
|
||||||
|
content := msg.Get("content").String()
|
||||||
|
|
||||||
|
if toolCallID != "" {
|
||||||
|
toolResult := KiroToolResult{
|
||||||
|
ToolUseID: toolCallID,
|
||||||
|
Content: []KiroTextContent{{Text: content}},
|
||||||
|
Status: "success",
|
||||||
|
}
|
||||||
|
// Tool results should be included in the next user message
|
||||||
|
// For now, collect them and they'll be handled when we build the current message
|
||||||
|
currentToolResults = append(currentToolResults, toolResult)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return history, currentUserMsg, currentToolResults
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildUserMessageFromOpenAI builds a user message from OpenAI format and extracts tool results
|
||||||
|
func buildUserMessageFromOpenAI(msg gjson.Result, modelID, origin string) (KiroUserInputMessage, []KiroToolResult) {
|
||||||
|
content := msg.Get("content")
|
||||||
|
var contentBuilder strings.Builder
|
||||||
|
var toolResults []KiroToolResult
|
||||||
|
var images []KiroImage
|
||||||
|
|
||||||
|
// Track seen toolCallIds to deduplicate
|
||||||
|
seenToolCallIDs := make(map[string]bool)
|
||||||
|
|
||||||
|
if content.IsArray() {
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
partType := part.Get("type").String()
|
||||||
|
switch partType {
|
||||||
|
case "text":
|
||||||
|
contentBuilder.WriteString(part.Get("text").String())
|
||||||
|
case "image_url":
|
||||||
|
imageURL := part.Get("image_url.url").String()
|
||||||
|
if strings.HasPrefix(imageURL, "data:") {
|
||||||
|
// Parse data URL: data:image/png;base64,xxxxx
|
||||||
|
if idx := strings.Index(imageURL, ";base64,"); idx != -1 {
|
||||||
|
mediaType := imageURL[5:idx] // Skip "data:"
|
||||||
|
data := imageURL[idx+8:] // Skip ";base64,"
|
||||||
|
|
||||||
|
format := ""
|
||||||
|
if lastSlash := strings.LastIndex(mediaType, "/"); lastSlash != -1 {
|
||||||
|
format = mediaType[lastSlash+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
if format != "" && data != "" {
|
||||||
|
images = append(images, KiroImage{
|
||||||
|
Format: format,
|
||||||
|
Source: KiroImageSource{
|
||||||
|
Bytes: data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if content.Type == gjson.String {
|
||||||
|
contentBuilder.WriteString(content.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for tool_calls in the message (shouldn't be in user messages, but handle edge cases)
|
||||||
|
_ = seenToolCallIDs // Used for deduplication if needed
|
||||||
|
|
||||||
|
userMsg := KiroUserInputMessage{
|
||||||
|
Content: contentBuilder.String(),
|
||||||
|
ModelID: modelID,
|
||||||
|
Origin: origin,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(images) > 0 {
|
||||||
|
userMsg.Images = images
|
||||||
|
}
|
||||||
|
|
||||||
|
return userMsg, toolResults
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildAssistantMessageFromOpenAI builds an assistant message from OpenAI format
|
||||||
|
func buildAssistantMessageFromOpenAI(msg gjson.Result) KiroAssistantResponseMessage {
|
||||||
|
content := msg.Get("content")
|
||||||
|
var contentBuilder strings.Builder
|
||||||
|
var toolUses []KiroToolUse
|
||||||
|
|
||||||
|
// Handle content
|
||||||
|
if content.Type == gjson.String {
|
||||||
|
contentBuilder.WriteString(content.String())
|
||||||
|
} else if content.IsArray() {
|
||||||
|
for _, part := range content.Array() {
|
||||||
|
if part.Get("type").String() == "text" {
|
||||||
|
contentBuilder.WriteString(part.Get("text").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle tool_calls
|
||||||
|
toolCalls := msg.Get("tool_calls")
|
||||||
|
if toolCalls.IsArray() {
|
||||||
|
for _, tc := range toolCalls.Array() {
|
||||||
|
if tc.Get("type").String() != "function" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
toolUseID := tc.Get("id").String()
|
||||||
|
toolName := tc.Get("function.name").String()
|
||||||
|
toolArgs := tc.Get("function.arguments").String()
|
||||||
|
|
||||||
|
var inputMap map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(toolArgs), &inputMap); err != nil {
|
||||||
|
log.Debugf("kiro-openai: failed to parse tool arguments: %v", err)
|
||||||
|
inputMap = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
toolUses = append(toolUses, KiroToolUse{
|
||||||
|
ToolUseID: toolUseID,
|
||||||
|
Name: toolName,
|
||||||
|
Input: inputMap,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return KiroAssistantResponseMessage{
|
||||||
|
Content: contentBuilder.String(),
|
||||||
|
ToolUses: toolUses,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildFinalContent builds the final content with system prompt
|
||||||
|
func buildFinalContent(content, systemPrompt string, toolResults []KiroToolResult) string {
|
||||||
|
var contentBuilder strings.Builder
|
||||||
|
|
||||||
|
if systemPrompt != "" {
|
||||||
|
contentBuilder.WriteString("--- SYSTEM PROMPT ---\n")
|
||||||
|
contentBuilder.WriteString(systemPrompt)
|
||||||
|
contentBuilder.WriteString("\n--- END SYSTEM PROMPT ---\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
contentBuilder.WriteString(content)
|
||||||
|
finalContent := contentBuilder.String()
|
||||||
|
|
||||||
|
// CRITICAL: Kiro API requires content to be non-empty
|
||||||
|
if strings.TrimSpace(finalContent) == "" {
|
||||||
|
if len(toolResults) > 0 {
|
||||||
|
finalContent = "Tool results provided."
|
||||||
|
} else {
|
||||||
|
finalContent = "Continue"
|
||||||
|
}
|
||||||
|
log.Debugf("kiro-openai: content was empty, using default: %s", finalContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
return finalContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkThinkingModeFromOpenAI checks if thinking mode is enabled in the OpenAI request.
|
||||||
|
// Returns (thinkingEnabled, budgetTokens).
|
||||||
|
// Supports:
|
||||||
|
// - reasoning_effort parameter (low/medium/high/auto)
|
||||||
|
// - Model name containing "thinking" or "reason"
|
||||||
|
// - <thinking_mode> tag in system prompt (AMP/Cursor format)
|
||||||
|
func checkThinkingModeFromOpenAI(openaiBody []byte) (bool, int64) {
|
||||||
|
var budgetTokens int64 = 16000 // Default budget
|
||||||
|
|
||||||
|
// Check OpenAI format: reasoning_effort parameter
|
||||||
|
// Valid values: "low", "medium", "high", "auto" (not "none")
|
||||||
|
reasoningEffort := gjson.GetBytes(openaiBody, "reasoning_effort")
|
||||||
|
if reasoningEffort.Exists() {
|
||||||
|
effort := reasoningEffort.String()
|
||||||
|
if effort != "" && effort != "none" {
|
||||||
|
log.Debugf("kiro-openai: thinking mode enabled via reasoning_effort: %s", effort)
|
||||||
|
// Adjust budget based on effort level
|
||||||
|
switch effort {
|
||||||
|
case "low":
|
||||||
|
budgetTokens = 8000
|
||||||
|
case "medium":
|
||||||
|
budgetTokens = 16000
|
||||||
|
case "high":
|
||||||
|
budgetTokens = 32000
|
||||||
|
case "auto":
|
||||||
|
budgetTokens = 16000
|
||||||
|
}
|
||||||
|
return true, budgetTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check AMP/Cursor format: <thinking_mode>interleaved</thinking_mode> in system prompt
|
||||||
|
bodyStr := string(openaiBody)
|
||||||
|
if strings.Contains(bodyStr, "<thinking_mode>") && strings.Contains(bodyStr, "</thinking_mode>") {
|
||||||
|
startTag := "<thinking_mode>"
|
||||||
|
endTag := "</thinking_mode>"
|
||||||
|
startIdx := strings.Index(bodyStr, startTag)
|
||||||
|
if startIdx >= 0 {
|
||||||
|
startIdx += len(startTag)
|
||||||
|
endIdx := strings.Index(bodyStr[startIdx:], endTag)
|
||||||
|
if endIdx >= 0 {
|
||||||
|
thinkingMode := bodyStr[startIdx : startIdx+endIdx]
|
||||||
|
if thinkingMode == "interleaved" || thinkingMode == "enabled" {
|
||||||
|
log.Debugf("kiro-openai: thinking mode enabled via AMP/Cursor format: %s", thinkingMode)
|
||||||
|
// Try to extract max_thinking_length if present
|
||||||
|
if maxLenStart := strings.Index(bodyStr, "<max_thinking_length>"); maxLenStart >= 0 {
|
||||||
|
maxLenStart += len("<max_thinking_length>")
|
||||||
|
if maxLenEnd := strings.Index(bodyStr[maxLenStart:], "</max_thinking_length>"); maxLenEnd >= 0 {
|
||||||
|
maxLenStr := bodyStr[maxLenStart : maxLenStart+maxLenEnd]
|
||||||
|
if parsed, err := fmt.Sscanf(maxLenStr, "%d", &budgetTokens); err == nil && parsed == 1 {
|
||||||
|
log.Debugf("kiro-openai: extracted max_thinking_length: %d", budgetTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, budgetTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check model name for thinking hints
|
||||||
|
model := gjson.GetBytes(openaiBody, "model").String()
|
||||||
|
modelLower := strings.ToLower(model)
|
||||||
|
if strings.Contains(modelLower, "thinking") || strings.Contains(modelLower, "-reason") {
|
||||||
|
log.Debugf("kiro-openai: thinking mode enabled via model name hint: %s", model)
|
||||||
|
return true, budgetTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("kiro-openai: no thinking mode detected in OpenAI request")
|
||||||
|
return false, budgetTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractToolChoiceHint extracts tool_choice from OpenAI request and returns a system prompt hint.
|
||||||
|
// OpenAI tool_choice values:
|
||||||
|
// - "none": Don't use any tools
|
||||||
|
// - "auto": Model decides (default, no hint needed)
|
||||||
|
// - "required": Must use at least one tool
|
||||||
|
// - {"type":"function","function":{"name":"..."}} : Must use specific tool
|
||||||
|
func extractToolChoiceHint(openaiBody []byte) string {
|
||||||
|
toolChoice := gjson.GetBytes(openaiBody, "tool_choice")
|
||||||
|
if !toolChoice.Exists() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle string values
|
||||||
|
if toolChoice.Type == gjson.String {
|
||||||
|
switch toolChoice.String() {
|
||||||
|
case "none":
|
||||||
|
// Note: When tool_choice is "none", we should ideally not pass tools at all
|
||||||
|
// But since we can't modify tool passing here, we add a strong hint
|
||||||
|
return "[INSTRUCTION: Do NOT use any tools. Respond with text only.]"
|
||||||
|
case "required":
|
||||||
|
return "[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]"
|
||||||
|
case "auto":
|
||||||
|
// Default behavior, no hint needed
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle object value: {"type":"function","function":{"name":"..."}}
|
||||||
|
if toolChoice.IsObject() {
|
||||||
|
if toolChoice.Get("type").String() == "function" {
|
||||||
|
toolName := toolChoice.Get("function.name").String()
|
||||||
|
if toolName != "" {
|
||||||
|
return fmt.Sprintf("[INSTRUCTION: You MUST use the tool named '%s' to respond. Do not use any other tool or respond with text only.]", toolName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractResponseFormatHint extracts response_format from OpenAI request and returns a system prompt hint.
|
||||||
|
// OpenAI response_format values:
|
||||||
|
// - {"type": "text"}: Default, no hint needed
|
||||||
|
// - {"type": "json_object"}: Must respond with valid JSON
|
||||||
|
// - {"type": "json_schema", "json_schema": {...}}: Must respond with JSON matching schema
|
||||||
|
func extractResponseFormatHint(openaiBody []byte) string {
|
||||||
|
responseFormat := gjson.GetBytes(openaiBody, "response_format")
|
||||||
|
if !responseFormat.Exists() {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
formatType := responseFormat.Get("type").String()
|
||||||
|
switch formatType {
|
||||||
|
case "json_object":
|
||||||
|
return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]"
|
||||||
|
case "json_schema":
|
||||||
|
// Extract schema if provided
|
||||||
|
schema := responseFormat.Get("json_schema.schema")
|
||||||
|
if schema.Exists() {
|
||||||
|
schemaStr := schema.Raw
|
||||||
|
// Truncate if too long
|
||||||
|
if len(schemaStr) > 500 {
|
||||||
|
schemaStr = schemaStr[:500] + "..."
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("[INSTRUCTION: You MUST respond with valid JSON that matches this schema: %s. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]", schemaStr)
|
||||||
|
}
|
||||||
|
return "[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]"
|
||||||
|
case "text":
|
||||||
|
// Default behavior, no hint needed
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// deduplicateToolResults removes duplicate tool results
|
||||||
|
func deduplicateToolResults(toolResults []KiroToolResult) []KiroToolResult {
|
||||||
|
if len(toolResults) == 0 {
|
||||||
|
return toolResults
|
||||||
|
}
|
||||||
|
|
||||||
|
seenIDs := make(map[string]bool)
|
||||||
|
unique := make([]KiroToolResult, 0, len(toolResults))
|
||||||
|
for _, tr := range toolResults {
|
||||||
|
if !seenIDs[tr.ToolUseID] {
|
||||||
|
seenIDs[tr.ToolUseID] = true
|
||||||
|
unique = append(unique, tr)
|
||||||
|
} else {
|
||||||
|
log.Debugf("kiro-openai: skipping duplicate toolResult: %s", tr.ToolUseID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return unique
|
||||||
|
}
|
||||||
264
internal/translator/kiro/openai/kiro_openai_response.go
Normal file
264
internal/translator/kiro/openai/kiro_openai_response.go
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
// Package openai provides response translation from Kiro to OpenAI format.
|
||||||
|
// This package handles the conversion of Kiro API responses into OpenAI Chat Completions-compatible
|
||||||
|
// JSON format, transforming streaming events and non-streaming responses.
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||||
|
var functionCallIDCounter uint64
|
||||||
|
|
||||||
|
// BuildOpenAIResponse constructs an OpenAI Chat Completions-compatible response.
|
||||||
|
// Supports tool_calls when tools are present in the response.
|
||||||
|
// stopReason is passed from upstream; fallback logic applied if empty.
|
||||||
|
func BuildOpenAIResponse(content string, toolUses []KiroToolUse, model string, usageInfo usage.Detail, stopReason string) []byte {
|
||||||
|
// Build the message object
|
||||||
|
message := map[string]interface{}{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tool_calls if present
|
||||||
|
if len(toolUses) > 0 {
|
||||||
|
var toolCalls []map[string]interface{}
|
||||||
|
for i, tu := range toolUses {
|
||||||
|
inputJSON, _ := json.Marshal(tu.Input)
|
||||||
|
toolCalls = append(toolCalls, map[string]interface{}{
|
||||||
|
"id": tu.ToolUseID,
|
||||||
|
"type": "function",
|
||||||
|
"index": i,
|
||||||
|
"function": map[string]interface{}{
|
||||||
|
"name": tu.Name,
|
||||||
|
"arguments": string(inputJSON),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
message["tool_calls"] = toolCalls
|
||||||
|
// When tool_calls are present, content should be null according to OpenAI spec
|
||||||
|
if content == "" {
|
||||||
|
message["content"] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use upstream stopReason; apply fallback logic if not provided
|
||||||
|
finishReason := mapKiroStopReasonToOpenAI(stopReason)
|
||||||
|
if finishReason == "" {
|
||||||
|
finishReason = "stop"
|
||||||
|
if len(toolUses) > 0 {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
log.Debugf("kiro-openai: buildOpenAIResponse using fallback finish_reason: %s", finishReason)
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:24],
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": message,
|
||||||
|
"finish_reason": finishReason,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"prompt_tokens": usageInfo.InputTokens,
|
||||||
|
"completion_tokens": usageInfo.OutputTokens,
|
||||||
|
"total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := json.Marshal(response)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapKiroStopReasonToOpenAI converts Kiro/Claude stop_reason to OpenAI finish_reason
|
||||||
|
func mapKiroStopReasonToOpenAI(stopReason string) string {
|
||||||
|
switch stopReason {
|
||||||
|
case "end_turn":
|
||||||
|
return "stop"
|
||||||
|
case "stop_sequence":
|
||||||
|
return "stop"
|
||||||
|
case "tool_use":
|
||||||
|
return "tool_calls"
|
||||||
|
case "max_tokens":
|
||||||
|
return "length"
|
||||||
|
case "content_filtered":
|
||||||
|
return "content_filter"
|
||||||
|
default:
|
||||||
|
return stopReason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAIStreamChunk constructs an OpenAI Chat Completions streaming chunk.
|
||||||
|
// This is the delta format used in streaming responses.
|
||||||
|
func BuildOpenAIStreamChunk(model string, deltaContent string, deltaToolCalls []map[string]interface{}, finishReason string, index int) []byte {
|
||||||
|
delta := map[string]interface{}{}
|
||||||
|
|
||||||
|
// First chunk should include role
|
||||||
|
if index == 0 && deltaContent == "" && len(deltaToolCalls) == 0 {
|
||||||
|
delta["role"] = "assistant"
|
||||||
|
delta["content"] = ""
|
||||||
|
} else if deltaContent != "" {
|
||||||
|
delta["content"] = deltaContent
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tool_calls delta if present
|
||||||
|
if len(deltaToolCalls) > 0 {
|
||||||
|
delta["tool_calls"] = deltaToolCalls
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := map[string]interface{}{
|
||||||
|
"index": 0,
|
||||||
|
"delta": delta,
|
||||||
|
}
|
||||||
|
|
||||||
|
if finishReason != "" {
|
||||||
|
choice["finish_reason"] = finishReason
|
||||||
|
} else {
|
||||||
|
choice["finish_reason"] = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{choice},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAIStreamChunkWithToolCallStart creates a stream chunk for tool call start
|
||||||
|
func BuildOpenAIStreamChunkWithToolCallStart(model string, toolUseID, toolName string, toolIndex int) []byte {
|
||||||
|
toolCall := map[string]interface{}{
|
||||||
|
"index": toolIndex,
|
||||||
|
"id": toolUseID,
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]interface{}{
|
||||||
|
"name": toolName,
|
||||||
|
"arguments": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
delta := map[string]interface{}{
|
||||||
|
"tool_calls": []map[string]interface{}{toolCall},
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := map[string]interface{}{
|
||||||
|
"index": 0,
|
||||||
|
"delta": delta,
|
||||||
|
"finish_reason": nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{choice},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAIStreamChunkWithToolCallDelta creates a stream chunk for tool call arguments delta
|
||||||
|
func BuildOpenAIStreamChunkWithToolCallDelta(model string, argumentsDelta string, toolIndex int) []byte {
|
||||||
|
toolCall := map[string]interface{}{
|
||||||
|
"index": toolIndex,
|
||||||
|
"function": map[string]interface{}{
|
||||||
|
"arguments": argumentsDelta,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
delta := map[string]interface{}{
|
||||||
|
"tool_calls": []map[string]interface{}{toolCall},
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := map[string]interface{}{
|
||||||
|
"index": 0,
|
||||||
|
"delta": delta,
|
||||||
|
"finish_reason": nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{choice},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAIStreamDoneChunk creates the final [DONE] stream event
|
||||||
|
func BuildOpenAIStreamDoneChunk() []byte {
|
||||||
|
return []byte("data: [DONE]")
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAIStreamFinishChunk creates the final chunk with finish_reason
|
||||||
|
func BuildOpenAIStreamFinishChunk(model string, finishReason string) []byte {
|
||||||
|
choice := map[string]interface{}{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]interface{}{},
|
||||||
|
"finish_reason": finishReason,
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{choice},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAIStreamUsageChunk creates a chunk with usage information (optional, for stream_options.include_usage)
|
||||||
|
func BuildOpenAIStreamUsageChunk(model string, usageInfo usage.Detail) []byte {
|
||||||
|
chunk := map[string]interface{}{
|
||||||
|
"id": "chatcmpl-" + uuid.New().String()[:12],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": time.Now().Unix(),
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]interface{}{},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"prompt_tokens": usageInfo.InputTokens,
|
||||||
|
"completion_tokens": usageInfo.OutputTokens,
|
||||||
|
"total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateToolCallID generates a unique tool call ID in OpenAI format
|
||||||
|
func GenerateToolCallID(toolName string) string {
|
||||||
|
return fmt.Sprintf("call_%s_%d_%d", toolName[:min(8, len(toolName))], time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
// min returns the minimum of two integers
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
212
internal/translator/kiro/openai/kiro_openai_stream.go
Normal file
212
internal/translator/kiro/openai/kiro_openai_stream.go
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
// Package openai provides streaming SSE event building for OpenAI format.
|
||||||
|
// This package handles the construction of OpenAI-compatible Server-Sent Events (SSE)
|
||||||
|
// for streaming responses from Kiro API.
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAIStreamState tracks the state of streaming response conversion
|
||||||
|
type OpenAIStreamState struct {
|
||||||
|
ChunkIndex int
|
||||||
|
ToolCallIndex int
|
||||||
|
HasSentFirstChunk bool
|
||||||
|
Model string
|
||||||
|
ResponseID string
|
||||||
|
Created int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOpenAIStreamState creates a new stream state for tracking
|
||||||
|
func NewOpenAIStreamState(model string) *OpenAIStreamState {
|
||||||
|
return &OpenAIStreamState{
|
||||||
|
ChunkIndex: 0,
|
||||||
|
ToolCallIndex: 0,
|
||||||
|
HasSentFirstChunk: false,
|
||||||
|
Model: model,
|
||||||
|
ResponseID: "chatcmpl-" + uuid.New().String()[:24],
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatSSEEvent formats a JSON payload for SSE streaming.
|
||||||
|
// Note: This returns raw JSON data without "data:" prefix.
|
||||||
|
// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go)
|
||||||
|
// to maintain architectural consistency and avoid double-prefix issues.
|
||||||
|
func FormatSSEEvent(data []byte) string {
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAISSETextDelta creates an SSE event for text content delta
|
||||||
|
func BuildOpenAISSETextDelta(state *OpenAIStreamState, textDelta string) string {
|
||||||
|
delta := map[string]interface{}{
|
||||||
|
"content": textDelta,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include role in first chunk
|
||||||
|
if !state.HasSentFirstChunk {
|
||||||
|
delta["role"] = "assistant"
|
||||||
|
state.HasSentFirstChunk = true
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk := buildBaseChunk(state, delta, nil)
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
state.ChunkIndex++
|
||||||
|
return FormatSSEEvent(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAISSEToolCallStart creates an SSE event for tool call start
|
||||||
|
func BuildOpenAISSEToolCallStart(state *OpenAIStreamState, toolUseID, toolName string) string {
|
||||||
|
toolCall := map[string]interface{}{
|
||||||
|
"index": state.ToolCallIndex,
|
||||||
|
"id": toolUseID,
|
||||||
|
"type": "function",
|
||||||
|
"function": map[string]interface{}{
|
||||||
|
"name": toolName,
|
||||||
|
"arguments": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
delta := map[string]interface{}{
|
||||||
|
"tool_calls": []map[string]interface{}{toolCall},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include role in first chunk if not sent yet
|
||||||
|
if !state.HasSentFirstChunk {
|
||||||
|
delta["role"] = "assistant"
|
||||||
|
state.HasSentFirstChunk = true
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk := buildBaseChunk(state, delta, nil)
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
state.ChunkIndex++
|
||||||
|
return FormatSSEEvent(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAISSEToolCallArgumentsDelta creates an SSE event for tool call arguments delta
|
||||||
|
func BuildOpenAISSEToolCallArgumentsDelta(state *OpenAIStreamState, argumentsDelta string, toolIndex int) string {
|
||||||
|
toolCall := map[string]interface{}{
|
||||||
|
"index": toolIndex,
|
||||||
|
"function": map[string]interface{}{
|
||||||
|
"arguments": argumentsDelta,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
delta := map[string]interface{}{
|
||||||
|
"tool_calls": []map[string]interface{}{toolCall},
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk := buildBaseChunk(state, delta, nil)
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
state.ChunkIndex++
|
||||||
|
return FormatSSEEvent(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAISSEFinish creates an SSE event with finish_reason
|
||||||
|
func BuildOpenAISSEFinish(state *OpenAIStreamState, finishReason string) string {
|
||||||
|
chunk := buildBaseChunk(state, map[string]interface{}{}, &finishReason)
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
state.ChunkIndex++
|
||||||
|
return FormatSSEEvent(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAISSEUsage creates an SSE event with usage information
|
||||||
|
func BuildOpenAISSEUsage(state *OpenAIStreamState, usageInfo usage.Detail) string {
|
||||||
|
chunk := map[string]interface{}{
|
||||||
|
"id": state.ResponseID,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": state.Created,
|
||||||
|
"model": state.Model,
|
||||||
|
"choices": []map[string]interface{}{},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"prompt_tokens": usageInfo.InputTokens,
|
||||||
|
"completion_tokens": usageInfo.OutputTokens,
|
||||||
|
"total_tokens": usageInfo.InputTokens + usageInfo.OutputTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
return FormatSSEEvent(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAISSEDone creates the final [DONE] SSE event.
|
||||||
|
// Note: This returns raw "[DONE]" without "data:" prefix.
|
||||||
|
// The SSE "data:" prefix is added by the Handler layer (e.g., openai_handlers.go)
|
||||||
|
// to maintain architectural consistency and avoid double-prefix issues.
|
||||||
|
func BuildOpenAISSEDone() string {
|
||||||
|
return "[DONE]"
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildBaseChunk creates a base chunk structure for streaming
|
||||||
|
func buildBaseChunk(state *OpenAIStreamState, delta map[string]interface{}, finishReason *string) map[string]interface{} {
|
||||||
|
choice := map[string]interface{}{
|
||||||
|
"index": 0,
|
||||||
|
"delta": delta,
|
||||||
|
}
|
||||||
|
|
||||||
|
if finishReason != nil {
|
||||||
|
choice["finish_reason"] = *finishReason
|
||||||
|
} else {
|
||||||
|
choice["finish_reason"] = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]interface{}{
|
||||||
|
"id": state.ResponseID,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": state.Created,
|
||||||
|
"model": state.Model,
|
||||||
|
"choices": []map[string]interface{}{choice},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAISSEReasoningDelta creates an SSE event for reasoning content delta
|
||||||
|
// This is used for o1/o3 style models that expose reasoning tokens
|
||||||
|
func BuildOpenAISSEReasoningDelta(state *OpenAIStreamState, reasoningDelta string) string {
|
||||||
|
delta := map[string]interface{}{
|
||||||
|
"reasoning_content": reasoningDelta,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Include role in first chunk
|
||||||
|
if !state.HasSentFirstChunk {
|
||||||
|
delta["role"] = "assistant"
|
||||||
|
state.HasSentFirstChunk = true
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk := buildBaseChunk(state, delta, nil)
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
state.ChunkIndex++
|
||||||
|
return FormatSSEEvent(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAISSEFirstChunk creates the first chunk with role only
|
||||||
|
func BuildOpenAISSEFirstChunk(state *OpenAIStreamState) string {
|
||||||
|
delta := map[string]interface{}{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
state.HasSentFirstChunk = true
|
||||||
|
chunk := buildBaseChunk(state, delta, nil)
|
||||||
|
result, _ := json.Marshal(chunk)
|
||||||
|
state.ChunkIndex++
|
||||||
|
return FormatSSEEvent(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ThinkingTagState tracks state for thinking tag detection in streaming
|
||||||
|
type ThinkingTagState struct {
|
||||||
|
InThinkingBlock bool
|
||||||
|
PendingStartChars int
|
||||||
|
PendingEndChars int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewThinkingTagState creates a new thinking tag state
|
||||||
|
func NewThinkingTagState() *ThinkingTagState {
|
||||||
|
return &ThinkingTagState{
|
||||||
|
InThinkingBlock: false,
|
||||||
|
PendingStartChars: 0,
|
||||||
|
PendingEndChars: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
@@ -179,6 +180,9 @@ func (w *Watcher) Start(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
log.Debugf("watching auth directory: %s", w.authDir)
|
log.Debugf("watching auth directory: %s", w.authDir)
|
||||||
|
|
||||||
|
// Watch Kiro IDE token file directory for automatic token updates
|
||||||
|
w.watchKiroIDETokenFile()
|
||||||
|
|
||||||
// Start the event processing goroutine
|
// Start the event processing goroutine
|
||||||
go w.processEvents(ctx)
|
go w.processEvents(ctx)
|
||||||
|
|
||||||
@@ -187,6 +191,31 @@ func (w *Watcher) Start(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// watchKiroIDETokenFile adds the Kiro IDE token file directory to the watcher.
|
||||||
|
// This enables automatic detection of token updates from Kiro IDE.
|
||||||
|
func (w *Watcher) watchKiroIDETokenFile() {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get home directory for Kiro IDE token watch: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kiro IDE stores tokens in ~/.aws/sso/cache/
|
||||||
|
kiroTokenDir := filepath.Join(homeDir, ".aws", "sso", "cache")
|
||||||
|
|
||||||
|
// Check if directory exists
|
||||||
|
if _, statErr := os.Stat(kiroTokenDir); os.IsNotExist(statErr) {
|
||||||
|
log.Debugf("Kiro IDE token directory does not exist: %s", kiroTokenDir)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if errAdd := w.watcher.Add(kiroTokenDir); errAdd != nil {
|
||||||
|
log.Debugf("failed to watch Kiro IDE token directory %s: %v", kiroTokenDir, errAdd)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("watching Kiro IDE token directory: %s", kiroTokenDir)
|
||||||
|
}
|
||||||
|
|
||||||
// Stop stops the file watcher
|
// Stop stops the file watcher
|
||||||
func (w *Watcher) Stop() error {
|
func (w *Watcher) Stop() error {
|
||||||
w.stopDispatch()
|
w.stopDispatch()
|
||||||
@@ -791,11 +820,21 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
|||||||
normalizedAuthDir := w.normalizeAuthPath(w.authDir)
|
normalizedAuthDir := w.normalizeAuthPath(w.authDir)
|
||||||
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
|
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
|
||||||
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
||||||
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
|
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
|
||||||
if !isConfigEvent && !isAuthJSON {
|
|
||||||
|
// Check for Kiro IDE token file changes
|
||||||
|
isKiroIDEToken := w.isKiroIDETokenFile(event.Name) && event.Op&authOps != 0
|
||||||
|
|
||||||
|
if !isConfigEvent && !isAuthJSON && !isKiroIDEToken {
|
||||||
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
|
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle Kiro IDE token file changes
|
||||||
|
if isKiroIDEToken {
|
||||||
|
w.handleKiroIDETokenChange(event)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
|
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
|
||||||
@@ -857,6 +896,51 @@ func (w *Watcher) scheduleConfigReload() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isKiroIDETokenFile checks if the given path is the Kiro IDE token file.
|
||||||
|
func (w *Watcher) isKiroIDETokenFile(path string) bool {
|
||||||
|
// Check if it's the kiro-auth-token.json file in ~/.aws/sso/cache/
|
||||||
|
// Use filepath.ToSlash to ensure consistent separators across platforms (Windows uses backslashes)
|
||||||
|
normalized := filepath.ToSlash(path)
|
||||||
|
return strings.HasSuffix(normalized, "kiro-auth-token.json") && strings.Contains(normalized, ".aws/sso/cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleKiroIDETokenChange processes changes to the Kiro IDE token file.
|
||||||
|
// When the token file is updated by Kiro IDE, this triggers a reload of Kiro auth.
|
||||||
|
func (w *Watcher) handleKiroIDETokenChange(event fsnotify.Event) {
|
||||||
|
log.Debugf("Kiro IDE token file event detected: %s %s", event.Op.String(), event.Name)
|
||||||
|
|
||||||
|
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
||||||
|
// Token file removed - wait briefly for potential atomic replace
|
||||||
|
time.Sleep(replaceCheckDelay)
|
||||||
|
if _, statErr := os.Stat(event.Name); statErr != nil {
|
||||||
|
log.Debugf("Kiro IDE token file removed: %s", event.Name)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to load the updated token
|
||||||
|
tokenData, err := kiroauth.LoadKiroIDEToken()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to load Kiro IDE token after change: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Kiro IDE token file updated, access token refreshed (provider: %s)", tokenData.Provider)
|
||||||
|
|
||||||
|
// Trigger auth state refresh to pick up the new token
|
||||||
|
w.refreshAuthState()
|
||||||
|
|
||||||
|
// Notify callback if set
|
||||||
|
w.clientsMutex.RLock()
|
||||||
|
cfg := w.config
|
||||||
|
w.clientsMutex.RUnlock()
|
||||||
|
|
||||||
|
if w.reloadCallback != nil && cfg != nil {
|
||||||
|
log.Debugf("triggering server update callback after Kiro IDE token change")
|
||||||
|
w.reloadCallback(cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (w *Watcher) reloadConfigIfChanged() {
|
func (w *Watcher) reloadConfigIfChanged() {
|
||||||
data, err := os.ReadFile(w.configPath)
|
data, err := os.ReadFile(w.configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1236,6 +1320,88 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
||||||
out = append(out, a)
|
out = append(out, a)
|
||||||
}
|
}
|
||||||
|
// Kiro (AWS CodeWhisperer) -> synthesize auths
|
||||||
|
var kAuth *kiroauth.KiroAuth
|
||||||
|
if len(cfg.KiroKey) > 0 {
|
||||||
|
kAuth = kiroauth.NewKiroAuth(cfg)
|
||||||
|
}
|
||||||
|
for i := range cfg.KiroKey {
|
||||||
|
kk := cfg.KiroKey[i]
|
||||||
|
var accessToken, profileArn, refreshToken string
|
||||||
|
|
||||||
|
// Try to load from token file first
|
||||||
|
if kk.TokenFile != "" && kAuth != nil {
|
||||||
|
tokenData, err := kAuth.LoadTokenFromFile(kk.TokenFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to load kiro token file %s: %v", kk.TokenFile, err)
|
||||||
|
} else {
|
||||||
|
accessToken = tokenData.AccessToken
|
||||||
|
profileArn = tokenData.ProfileArn
|
||||||
|
refreshToken = tokenData.RefreshToken
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Override with direct config values if provided
|
||||||
|
if kk.AccessToken != "" {
|
||||||
|
accessToken = kk.AccessToken
|
||||||
|
}
|
||||||
|
if kk.ProfileArn != "" {
|
||||||
|
profileArn = kk.ProfileArn
|
||||||
|
}
|
||||||
|
if kk.RefreshToken != "" {
|
||||||
|
refreshToken = kk.RefreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
if accessToken == "" {
|
||||||
|
log.Warnf("kiro config[%d] missing access_token, skipping", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// profileArn is optional for AWS Builder ID users
|
||||||
|
id, token := idGen.next("kiro:token", accessToken, profileArn)
|
||||||
|
attrs := map[string]string{
|
||||||
|
"source": fmt.Sprintf("config:kiro[%s]", token),
|
||||||
|
"access_token": accessToken,
|
||||||
|
}
|
||||||
|
if profileArn != "" {
|
||||||
|
attrs["profile_arn"] = profileArn
|
||||||
|
}
|
||||||
|
if kk.Region != "" {
|
||||||
|
attrs["region"] = kk.Region
|
||||||
|
}
|
||||||
|
if kk.AgentTaskType != "" {
|
||||||
|
attrs["agent_task_type"] = kk.AgentTaskType
|
||||||
|
}
|
||||||
|
if kk.PreferredEndpoint != "" {
|
||||||
|
attrs["preferred_endpoint"] = kk.PreferredEndpoint
|
||||||
|
} else if cfg.KiroPreferredEndpoint != "" {
|
||||||
|
// Apply global default if not overridden by specific key
|
||||||
|
attrs["preferred_endpoint"] = cfg.KiroPreferredEndpoint
|
||||||
|
}
|
||||||
|
if refreshToken != "" {
|
||||||
|
attrs["refresh_token"] = refreshToken
|
||||||
|
}
|
||||||
|
proxyURL := strings.TrimSpace(kk.ProxyURL)
|
||||||
|
a := &coreauth.Auth{
|
||||||
|
ID: id,
|
||||||
|
Provider: "kiro",
|
||||||
|
Label: "kiro-token",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
ProxyURL: proxyURL,
|
||||||
|
Attributes: attrs,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
if refreshToken != "" {
|
||||||
|
if a.Metadata == nil {
|
||||||
|
a.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
a.Metadata["refresh_token"] = refreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
out = append(out, a)
|
||||||
|
}
|
||||||
for i := range cfg.OpenAICompatibility {
|
for i := range cfg.OpenAICompatibility {
|
||||||
compat := &cfg.OpenAICompatibility[i]
|
compat := &cfg.OpenAICompatibility[i]
|
||||||
providerName := strings.ToLower(strings.TrimSpace(compat.Name))
|
providerName := strings.ToLower(strings.TrimSpace(compat.Name))
|
||||||
@@ -1342,7 +1508,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Also synthesize auth entries directly from auth files (for OAuth/file-backed providers)
|
// Also synthesize auth entries directly from auth files (for OAuth/file-backed providers)
|
||||||
entries, _ := os.ReadDir(w.authDir)
|
log.Debugf("SnapshotCoreAuths: scanning auth directory: %s", w.authDir)
|
||||||
|
entries, readErr := os.ReadDir(w.authDir)
|
||||||
|
if readErr != nil {
|
||||||
|
log.Errorf("SnapshotCoreAuths: failed to read auth directory %s: %v", w.authDir, readErr)
|
||||||
|
}
|
||||||
|
log.Debugf("SnapshotCoreAuths: found %d entries in auth directory", len(entries))
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
if e.IsDir() {
|
if e.IsDir() {
|
||||||
continue
|
continue
|
||||||
@@ -1361,9 +1532,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
t, _ := metadata["type"].(string)
|
t, _ := metadata["type"].(string)
|
||||||
|
|
||||||
|
// Detect Kiro auth files by auth_method field (they don't have "type" field)
|
||||||
if t == "" {
|
if t == "" {
|
||||||
|
if authMethod, _ := metadata["auth_method"].(string); authMethod == "builder-id" || authMethod == "social" {
|
||||||
|
t = "kiro"
|
||||||
|
log.Debugf("SnapshotCoreAuths: detected Kiro auth by auth_method: %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if t == "" {
|
||||||
|
log.Debugf("SnapshotCoreAuths: skipping file without type: %s", name)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
log.Debugf("SnapshotCoreAuths: processing auth file: %s (type=%s)", name, t)
|
||||||
provider := strings.ToLower(t)
|
provider := strings.ToLower(t)
|
||||||
if provider == "gemini" {
|
if provider == "gemini" {
|
||||||
provider = "gemini-cli"
|
provider = "gemini-cli"
|
||||||
@@ -1372,6 +1554,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
if email, _ := metadata["email"].(string); email != "" {
|
if email, _ := metadata["email"].(string); email != "" {
|
||||||
label = email
|
label = email
|
||||||
}
|
}
|
||||||
|
// For Kiro, use provider field as label if available
|
||||||
|
if provider == "kiro" {
|
||||||
|
if kiroProvider, _ := metadata["provider"].(string); kiroProvider != "" {
|
||||||
|
label = fmt.Sprintf("kiro-%s", strings.ToLower(kiroProvider))
|
||||||
|
}
|
||||||
|
}
|
||||||
// Use relative path under authDir as ID to stay consistent with the file-based token store
|
// Use relative path under authDir as ID to stay consistent with the file-based token store
|
||||||
id := full
|
id := full
|
||||||
if rel, errRel := filepath.Rel(w.authDir, full); errRel == nil && rel != "" {
|
if rel, errRel := filepath.Rel(w.authDir, full); errRel == nil && rel != "" {
|
||||||
@@ -1397,6 +1585,27 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
}
|
}
|
||||||
|
// Set NextRefreshAfter for Kiro auth based on expires_at
|
||||||
|
if provider == "kiro" {
|
||||||
|
if expiresAtStr, ok := metadata["expires_at"].(string); ok && expiresAtStr != "" {
|
||||||
|
if expiresAt, parseErr := time.Parse(time.RFC3339, expiresAtStr); parseErr == nil {
|
||||||
|
// Refresh 30 minutes before expiry
|
||||||
|
a.NextRefreshAfter = expiresAt.Add(-30 * time.Minute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply global preferred endpoint setting if not present in metadata
|
||||||
|
if cfg.KiroPreferredEndpoint != "" {
|
||||||
|
// Check if already set in metadata (which takes precedence in executor)
|
||||||
|
if _, hasMeta := metadata["preferred_endpoint"]; !hasMeta {
|
||||||
|
if a.Attributes == nil {
|
||||||
|
a.Attributes = make(map[string]string)
|
||||||
|
}
|
||||||
|
a.Attributes["preferred_endpoint"] = cfg.KiroPreferredEndpoint
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
applyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
|
applyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
|
||||||
if provider == "gemini-cli" {
|
if provider == "gemini-cli" {
|
||||||
if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
@@ -51,7 +52,25 @@ type BaseAPIHandler struct {
|
|||||||
Cfg *config.SDKConfig
|
Cfg *config.SDKConfig
|
||||||
|
|
||||||
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
||||||
OpenAICompatProviders []string
|
openAICompatProviders []string
|
||||||
|
openAICompatMutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenAICompatProviders safely returns a copy of the provider names
|
||||||
|
func (h *BaseAPIHandler) GetOpenAICompatProviders() []string {
|
||||||
|
h.openAICompatMutex.RLock()
|
||||||
|
defer h.openAICompatMutex.RUnlock()
|
||||||
|
result := make([]string, len(h.openAICompatProviders))
|
||||||
|
copy(result, h.openAICompatProviders)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetOpenAICompatProviders safely sets the provider names
|
||||||
|
func (h *BaseAPIHandler) SetOpenAICompatProviders(providers []string) {
|
||||||
|
h.openAICompatMutex.Lock()
|
||||||
|
defer h.openAICompatMutex.Unlock()
|
||||||
|
h.openAICompatProviders = make([]string, len(providers))
|
||||||
|
copy(h.openAICompatProviders, providers)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBaseAPIHandlers creates a new API handlers instance.
|
// NewBaseAPIHandlers creates a new API handlers instance.
|
||||||
@@ -64,11 +83,12 @@ type BaseAPIHandler struct {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - *BaseAPIHandler: A new API handlers instance
|
// - *BaseAPIHandler: A new API handlers instance
|
||||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
||||||
return &BaseAPIHandler{
|
h := &BaseAPIHandler{
|
||||||
Cfg: cfg,
|
Cfg: cfg,
|
||||||
AuthManager: authManager,
|
AuthManager: authManager,
|
||||||
OpenAICompatProviders: openAICompatProviders,
|
|
||||||
}
|
}
|
||||||
|
h.SetOpenAICompatProviders(openAICompatProviders)
|
||||||
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateClients updates the handlers' client list and configuration.
|
// UpdateClients updates the handlers' client list and configuration.
|
||||||
@@ -398,7 +418,7 @@ func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, mode
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if the provider is a configured openai-compatibility provider
|
// Check if the provider is a configured openai-compatibility provider
|
||||||
for _, pName := range h.OpenAICompatProviders {
|
for _, pName := range h.GetOpenAICompatProviders() {
|
||||||
if pName == providerPart {
|
if pName == providerPart {
|
||||||
return providerPart, modelPart, true
|
return providerPart, modelPart, true
|
||||||
}
|
}
|
||||||
|
|||||||
129
sdk/auth/github_copilot.go
Normal file
129
sdk/auth/github_copilot.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/copilot"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GitHubCopilotAuthenticator implements the OAuth device flow login for GitHub Copilot.
|
||||||
|
type GitHubCopilotAuthenticator struct{}
|
||||||
|
|
||||||
|
// NewGitHubCopilotAuthenticator constructs a new GitHub Copilot authenticator.
|
||||||
|
func NewGitHubCopilotAuthenticator() Authenticator {
|
||||||
|
return &GitHubCopilotAuthenticator{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider returns the provider key for github-copilot.
|
||||||
|
func (GitHubCopilotAuthenticator) Provider() string {
|
||||||
|
return "github-copilot"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshLead returns nil since GitHub OAuth tokens don't expire in the traditional sense.
|
||||||
|
// The token remains valid until the user revokes it or the Copilot subscription expires.
|
||||||
|
func (GitHubCopilotAuthenticator) RefreshLead() *time.Duration {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login initiates the GitHub device flow authentication for Copilot access.
|
||||||
|
func (a GitHubCopilotAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||||
|
}
|
||||||
|
if opts == nil {
|
||||||
|
opts = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
authSvc := copilot.NewCopilotAuth(cfg)
|
||||||
|
|
||||||
|
// Start the device flow
|
||||||
|
fmt.Println("Starting GitHub Copilot authentication...")
|
||||||
|
deviceCode, err := authSvc.StartDeviceFlow(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("github-copilot: failed to start device flow: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display the user code and verification URL
|
||||||
|
fmt.Printf("\nTo authenticate, please visit: %s\n", deviceCode.VerificationURI)
|
||||||
|
fmt.Printf("And enter the code: %s\n\n", deviceCode.UserCode)
|
||||||
|
|
||||||
|
// Try to open the browser automatically
|
||||||
|
if !opts.NoBrowser {
|
||||||
|
if browser.IsAvailable() {
|
||||||
|
if errOpen := browser.OpenURL(deviceCode.VerificationURI); errOpen != nil {
|
||||||
|
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Waiting for GitHub authorization...")
|
||||||
|
fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn)
|
||||||
|
|
||||||
|
// Wait for user authorization
|
||||||
|
authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := copilot.GetUserFriendlyMessage(err)
|
||||||
|
return nil, fmt.Errorf("github-copilot: %s", errMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the token can get a Copilot API token
|
||||||
|
fmt.Println("Verifying Copilot access...")
|
||||||
|
apiToken, err := authSvc.GetCopilotAPIToken(ctx, authBundle.TokenData.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("github-copilot: failed to verify Copilot access - you may not have an active Copilot subscription: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the token storage
|
||||||
|
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||||
|
|
||||||
|
// Build metadata with token information for the executor
|
||||||
|
metadata := map[string]any{
|
||||||
|
"type": "github-copilot",
|
||||||
|
"username": authBundle.Username,
|
||||||
|
"access_token": authBundle.TokenData.AccessToken,
|
||||||
|
"token_type": authBundle.TokenData.TokenType,
|
||||||
|
"scope": authBundle.TokenData.Scope,
|
||||||
|
"timestamp": time.Now().UnixMilli(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if apiToken.ExpiresAt > 0 {
|
||||||
|
metadata["api_token_expires_at"] = apiToken.ExpiresAt
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := fmt.Sprintf("github-copilot-%s.json", authBundle.Username)
|
||||||
|
|
||||||
|
fmt.Printf("\nGitHub Copilot authentication successful for user: %s\n", authBundle.Username)
|
||||||
|
|
||||||
|
return &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: a.Provider(),
|
||||||
|
FileName: fileName,
|
||||||
|
Label: authBundle.Username,
|
||||||
|
Storage: tokenStorage,
|
||||||
|
Metadata: metadata,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshGitHubCopilotToken validates and returns the current token status.
|
||||||
|
// GitHub OAuth tokens don't need traditional refresh - we just validate they still work.
|
||||||
|
func RefreshGitHubCopilotToken(ctx context.Context, cfg *config.Config, storage *copilot.CopilotTokenStorage) error {
|
||||||
|
if storage == nil || storage.AccessToken == "" {
|
||||||
|
return fmt.Errorf("no token available")
|
||||||
|
}
|
||||||
|
|
||||||
|
authSvc := copilot.NewCopilotAuth(cfg)
|
||||||
|
|
||||||
|
// Validate the token can still get a Copilot API token
|
||||||
|
_, err := authSvc.GetCopilotAPIToken(ctx, storage.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("token validation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
363
sdk/auth/kiro.go
Normal file
363
sdk/auth/kiro.go
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
kiroauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kiro"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// extractKiroIdentifier extracts a meaningful identifier for file naming.
|
||||||
|
// Returns account name if provided, otherwise profile ARN ID.
|
||||||
|
// All extracted values are sanitized to prevent path injection attacks.
|
||||||
|
func extractKiroIdentifier(accountName, profileArn string) string {
|
||||||
|
// Priority 1: Use account name if provided
|
||||||
|
if accountName != "" {
|
||||||
|
return kiroauth.SanitizeEmailForFilename(accountName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Priority 2: Use profile ARN ID part (sanitized to prevent path injection)
|
||||||
|
if profileArn != "" {
|
||||||
|
parts := strings.Split(profileArn, "/")
|
||||||
|
if len(parts) >= 2 {
|
||||||
|
// Sanitize the ARN component to prevent path traversal
|
||||||
|
return kiroauth.SanitizeEmailForFilename(parts[len(parts)-1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: timestamp
|
||||||
|
return fmt.Sprintf("%d", time.Now().UnixNano()%100000)
|
||||||
|
}
|
||||||
|
|
||||||
|
// KiroAuthenticator implements OAuth authentication for Kiro with Google login.
|
||||||
|
type KiroAuthenticator struct{}
|
||||||
|
|
||||||
|
// NewKiroAuthenticator constructs a Kiro authenticator.
|
||||||
|
func NewKiroAuthenticator() *KiroAuthenticator {
|
||||||
|
return &KiroAuthenticator{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider returns the provider key for the authenticator.
|
||||||
|
func (a *KiroAuthenticator) Provider() string {
|
||||||
|
return "kiro"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshLead indicates how soon before expiry a refresh should be attempted.
|
||||||
|
// Set to 5 minutes to match Antigravity and avoid frequent refresh checks while still ensuring timely token refresh.
|
||||||
|
func (a *KiroAuthenticator) RefreshLead() *time.Duration {
|
||||||
|
d := 5 * time.Minute
|
||||||
|
return &d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login performs OAuth login for Kiro with AWS Builder ID.
|
||||||
|
func (a *KiroAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||||
|
|
||||||
|
// Use AWS Builder ID device code flow
|
||||||
|
tokenData, err := oauth.LoginWithBuilderID(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("login failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse expires_at
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
expiresAt = time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract identifier for file naming
|
||||||
|
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-aws-%s.json", idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Label: "kiro-aws",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenData.AccessToken,
|
||||||
|
"refresh_token": tokenData.RefreshToken,
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"expires_at": tokenData.ExpiresAt,
|
||||||
|
"auth_method": tokenData.AuthMethod,
|
||||||
|
"provider": tokenData.Provider,
|
||||||
|
"client_id": tokenData.ClientID,
|
||||||
|
"client_secret": tokenData.ClientSecret,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"source": "aws-builder-id",
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
|
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenData.Email != "" {
|
||||||
|
fmt.Printf("\n✓ Kiro authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||||
|
} else {
|
||||||
|
fmt.Println("\n✓ Kiro authentication completed successfully!")
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithGoogle performs OAuth login for Kiro with Google.
|
||||||
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
|
func (a *KiroAuthenticator) LoginWithGoogle(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||||
|
|
||||||
|
// Use Google OAuth flow with protocol handler
|
||||||
|
tokenData, err := oauth.LoginWithGoogle(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("google login failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse expires_at
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
expiresAt = time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract identifier for file naming
|
||||||
|
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-google-%s.json", idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Label: "kiro-google",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenData.AccessToken,
|
||||||
|
"refresh_token": tokenData.RefreshToken,
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"expires_at": tokenData.ExpiresAt,
|
||||||
|
"auth_method": tokenData.AuthMethod,
|
||||||
|
"provider": tokenData.Provider,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"source": "google-oauth",
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
|
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenData.Email != "" {
|
||||||
|
fmt.Printf("\n✓ Kiro Google authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||||
|
} else {
|
||||||
|
fmt.Println("\n✓ Kiro Google authentication completed successfully!")
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithGitHub performs OAuth login for Kiro with GitHub.
|
||||||
|
// This uses a custom protocol handler (kiro://) to receive the callback.
|
||||||
|
func (a *KiroAuthenticator) LoginWithGitHub(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("kiro auth: configuration is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||||
|
|
||||||
|
// Use GitHub OAuth flow with protocol handler
|
||||||
|
tokenData, err := oauth.LoginWithGitHub(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("github login failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse expires_at
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
expiresAt = time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract identifier for file naming
|
||||||
|
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-github-%s.json", idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Label: "kiro-github",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenData.AccessToken,
|
||||||
|
"refresh_token": tokenData.RefreshToken,
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"expires_at": tokenData.ExpiresAt,
|
||||||
|
"auth_method": tokenData.AuthMethod,
|
||||||
|
"provider": tokenData.Provider,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"source": "github-oauth",
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
|
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenData.Email != "" {
|
||||||
|
fmt.Printf("\n✓ Kiro GitHub authentication completed successfully! (Account: %s)\n", tokenData.Email)
|
||||||
|
} else {
|
||||||
|
fmt.Println("\n✓ Kiro GitHub authentication completed successfully!")
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImportFromKiroIDE imports token from Kiro IDE's token file.
|
||||||
|
func (a *KiroAuthenticator) ImportFromKiroIDE(ctx context.Context, cfg *config.Config) (*coreauth.Auth, error) {
|
||||||
|
tokenData, err := kiroauth.LoadKiroIDEToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load Kiro IDE token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse expires_at
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
expiresAt = time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract email from JWT if not already set (for imported tokens)
|
||||||
|
if tokenData.Email == "" {
|
||||||
|
tokenData.Email = kiroauth.ExtractEmailFromJWT(tokenData.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract identifier for file naming
|
||||||
|
idPart := extractKiroIdentifier(tokenData.Email, tokenData.ProfileArn)
|
||||||
|
// Sanitize provider to prevent path traversal (defense-in-depth)
|
||||||
|
provider := kiroauth.SanitizeEmailForFilename(strings.ToLower(strings.TrimSpace(tokenData.Provider)))
|
||||||
|
if provider == "" {
|
||||||
|
provider = "imported" // Fallback for legacy tokens without provider
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fileName := fmt.Sprintf("kiro-%s-%s.json", provider, idPart)
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "kiro",
|
||||||
|
FileName: fileName,
|
||||||
|
Label: fmt.Sprintf("kiro-%s", provider),
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"type": "kiro",
|
||||||
|
"access_token": tokenData.AccessToken,
|
||||||
|
"refresh_token": tokenData.RefreshToken,
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"expires_at": tokenData.ExpiresAt,
|
||||||
|
"auth_method": tokenData.AuthMethod,
|
||||||
|
"provider": tokenData.Provider,
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"profile_arn": tokenData.ProfileArn,
|
||||||
|
"source": "kiro-ide-import",
|
||||||
|
"email": tokenData.Email,
|
||||||
|
},
|
||||||
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
|
NextRefreshAfter: expiresAt.Add(-5 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display the email if extracted
|
||||||
|
if tokenData.Email != "" {
|
||||||
|
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s, Account: %s)\n", tokenData.Provider, tokenData.Email)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("\n✓ Imported Kiro token from IDE (Provider: %s)\n", tokenData.Provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh refreshes an expired Kiro token using AWS SSO OIDC.
|
||||||
|
func (a *KiroAuthenticator) Refresh(ctx context.Context, cfg *config.Config, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
if auth == nil || auth.Metadata == nil {
|
||||||
|
return nil, fmt.Errorf("invalid auth record")
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshToken, ok := auth.Metadata["refresh_token"].(string)
|
||||||
|
if !ok || refreshToken == "" {
|
||||||
|
return nil, fmt.Errorf("refresh token not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID, _ := auth.Metadata["client_id"].(string)
|
||||||
|
clientSecret, _ := auth.Metadata["client_secret"].(string)
|
||||||
|
authMethod, _ := auth.Metadata["auth_method"].(string)
|
||||||
|
|
||||||
|
var tokenData *kiroauth.KiroTokenData
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Use SSO OIDC refresh for AWS Builder ID, otherwise use Kiro's OAuth refresh endpoint
|
||||||
|
if clientID != "" && clientSecret != "" && authMethod == "builder-id" {
|
||||||
|
ssoClient := kiroauth.NewSSOOIDCClient(cfg)
|
||||||
|
tokenData, err = ssoClient.RefreshToken(ctx, clientID, clientSecret, refreshToken)
|
||||||
|
} else {
|
||||||
|
// Fallback to Kiro's refresh endpoint (for social auth: Google/GitHub)
|
||||||
|
oauth := kiroauth.NewKiroOAuth(cfg)
|
||||||
|
tokenData, err = oauth.RefreshToken(ctx, refreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("token refresh failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse expires_at
|
||||||
|
expiresAt, err := time.Parse(time.RFC3339, tokenData.ExpiresAt)
|
||||||
|
if err != nil {
|
||||||
|
expiresAt = time.Now().Add(1 * time.Hour)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone auth to avoid mutating the input parameter
|
||||||
|
updated := auth.Clone()
|
||||||
|
now := time.Now()
|
||||||
|
updated.UpdatedAt = now
|
||||||
|
updated.LastRefreshedAt = now
|
||||||
|
updated.Metadata["access_token"] = tokenData.AccessToken
|
||||||
|
updated.Metadata["refresh_token"] = tokenData.RefreshToken
|
||||||
|
updated.Metadata["expires_at"] = tokenData.ExpiresAt
|
||||||
|
updated.Metadata["last_refresh"] = now.Format(time.RFC3339) // For double-check optimization
|
||||||
|
// NextRefreshAfter is aligned with RefreshLead (5min)
|
||||||
|
updated.NextRefreshAfter = expiresAt.Add(-5 * time.Minute)
|
||||||
|
|
||||||
|
return updated, nil
|
||||||
|
}
|
||||||
@@ -74,3 +74,16 @@ func (m *Manager) Login(ctx context.Context, provider string, cfg *config.Config
|
|||||||
}
|
}
|
||||||
return record, savedPath, nil
|
return record, savedPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveAuth persists an auth record directly without going through the login flow.
|
||||||
|
func (m *Manager) SaveAuth(record *coreauth.Auth, cfg *config.Config) (string, error) {
|
||||||
|
if m.store == nil {
|
||||||
|
return "", fmt.Errorf("no store configured")
|
||||||
|
}
|
||||||
|
if cfg != nil {
|
||||||
|
if dirSetter, ok := m.store.(interface{ SetBaseDir(string) }); ok {
|
||||||
|
dirSetter.SetBaseDir(cfg.AuthDir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m.store.Save(context.Background(), record)
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ func init() {
|
|||||||
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
|
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
|
||||||
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
||||||
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
||||||
|
registerRefreshLead("kiro", func() Authenticator { return NewKiroAuthenticator() })
|
||||||
|
registerRefreshLead("github-copilot", func() Authenticator { return NewGitHubCopilotAuthenticator() })
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerRefreshLead(provider string, factory func() Authenticator) {
|
func registerRefreshLead(provider string, factory func() Authenticator) {
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ type RefreshEvaluator interface {
|
|||||||
const (
|
const (
|
||||||
refreshCheckInterval = 5 * time.Second
|
refreshCheckInterval = 5 * time.Second
|
||||||
refreshPendingBackoff = time.Minute
|
refreshPendingBackoff = time.Minute
|
||||||
refreshFailureBackoff = 5 * time.Minute
|
refreshFailureBackoff = 1 * time.Minute
|
||||||
quotaBackoffBase = time.Second
|
quotaBackoffBase = time.Second
|
||||||
quotaBackoffMax = 30 * time.Minute
|
quotaBackoffMax = 30 * time.Minute
|
||||||
)
|
)
|
||||||
@@ -1498,7 +1498,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
|||||||
updated.Runtime = auth.Runtime
|
updated.Runtime = auth.Runtime
|
||||||
}
|
}
|
||||||
updated.LastRefreshedAt = now
|
updated.LastRefreshedAt = now
|
||||||
updated.NextRefreshAfter = time.Time{}
|
// Preserve NextRefreshAfter set by the Authenticator
|
||||||
|
// If the Authenticator set a reasonable refresh time, it should not be overwritten
|
||||||
|
// If the Authenticator did not set it (zero value), shouldRefresh will use default logic
|
||||||
updated.LastError = nil
|
updated.LastError = nil
|
||||||
updated.UpdatedAt = now
|
updated.UpdatedAt = now
|
||||||
_, _ = m.Update(ctx, updated)
|
_, _ = m.Update(ctx, updated)
|
||||||
|
|||||||
@@ -379,6 +379,10 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
|||||||
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
||||||
case "iflow":
|
case "iflow":
|
||||||
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
|
||||||
|
case "kiro":
|
||||||
|
s.coreManager.RegisterExecutor(executor.NewKiroExecutor(s.cfg))
|
||||||
|
case "github-copilot":
|
||||||
|
s.coreManager.RegisterExecutor(executor.NewGitHubCopilotExecutor(s.cfg))
|
||||||
default:
|
default:
|
||||||
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||||
if providerKey == "" {
|
if providerKey == "" {
|
||||||
@@ -720,6 +724,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
models = applyExcludedModels(models, excluded)
|
models = applyExcludedModels(models, excluded)
|
||||||
case "iflow":
|
case "iflow":
|
||||||
models = registry.GetIFlowModels()
|
models = registry.GetIFlowModels()
|
||||||
|
case "github-copilot":
|
||||||
|
models = registry.GetGitHubCopilotModels()
|
||||||
|
models = applyExcludedModels(models, excluded)
|
||||||
|
case "kiro":
|
||||||
|
models = registry.GetKiroModels()
|
||||||
models = applyExcludedModels(models, excluded)
|
models = applyExcludedModels(models, excluded)
|
||||||
default:
|
default:
|
||||||
// Handle OpenAI-compatibility providers by name using config
|
// Handle OpenAI-compatibility providers by name using config
|
||||||
|
|||||||
Reference in New Issue
Block a user