diff --git a/README.md b/README.md
index 9926f81b..06400ae6 100644
--- a/README.md
+++ b/README.md
@@ -2,25 +2,31 @@
 English | [中文](README_CN.md)
 
-A proxy server that provides an OpenAI/Gemini/Claude compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI/Gemini/Claude API.
+A proxy server that provides OpenAI/Gemini/Claude compatible API interfaces for CLI models.
+
+It now also supports OpenAI Codex (GPT models) via OAuth.
+
+This lets you use local or multi‑account CLI access with any OpenAI‑compatible client or SDK.
 
 ## Features
 
 - OpenAI/Gemini/Claude compatible API endpoints for CLI models
-- Support for both streaming and non-streaming responses
+- OpenAI Codex support (GPT models) via OAuth login
+- Streaming and non-streaming responses
 - Function calling/tools support
 - Multimodal input support (text and images)
-- Multiple account support with load balancing
-- Simple CLI authentication flow
-- Support for Generative Language API Key
-- Support Gemini CLI with multiple account load balancing
+- Multiple accounts with round‑robin load balancing (Gemini and OpenAI)
+- Simple CLI authentication flows (Gemini and OpenAI)
+- Generative Language API Key support
+- Gemini CLI multi‑account load balancing
 
 ## Installation
 
 ### Prerequisites
 
 - Go 1.24 or higher
-- A Google account with access to CLI models
+- A Google account with access to Gemini CLI models (optional)
+- An OpenAI account for Codex/GPT access (optional)
 
 ### Building from Source
 
@@ -39,17 +45,23 @@ A proxy server that provides an OpenAI/Gemini/Claude compatible API interface fo
 
 ### Authentication
 
-Before using the API, you need to authenticate with your Google account:
+You can authenticate for Gemini and/or OpenAI. Both can coexist in the same `auth-dir` and will be load balanced.
 
-```bash
-./cli-proxy-api --login
-```
+- Gemini (Google):
+  ```bash
+  ./cli-proxy-api --login
+  ```
+  If you are a legacy Gemini Code user, you may need to specify a project ID:
+  ```bash
+  ./cli-proxy-api --login --project_id <your_project_id>
+  ```
+  The local OAuth callback uses port `8085`.
 
-If you are an old gemini code user, you may need to specify a project ID:
-
-```bash
-./cli-proxy-api --login --project_id
-```
+- OpenAI (Codex/GPT via OAuth):
+  ```bash
+  ./cli-proxy-api --codex-login
+  ```
+  Add `--no-browser` to print the login URL instead of opening a browser. The local OAuth callback uses port `1455`.
 
 ### Starting the Server
 
@@ -90,6 +102,15 @@ Request body example:
 }
 ```
 
+Notes:
+- Use a `gemini-*` model for Gemini (e.g., `gemini-2.5-pro`) or a `gpt-*` model for OpenAI (e.g., `gpt-5`). The proxy will route to the correct provider automatically.
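+
+A minimal `curl` sketch of this routing (a hypothetical request; it assumes the server is running on the default port and that `your-api-key-1` is configured under `api-keys`):
+
+```bash
+curl http://localhost:8317/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer your-api-key-1" \
+  -d '{"model": "gpt-5", "messages": [{"role": "user", "content": "Hello"}]}'
+```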
+ +#### Claude Messages (SSE-compatible) + +``` +POST http://localhost:8317/v1/messages +``` + ### Using with OpenAI Libraries You can use this proxy with any OpenAI-compatible library by setting the base URL to your local server: @@ -104,14 +125,19 @@ client = OpenAI( base_url="http://localhost:8317/v1" ) -response = client.chat.completions.create( +# Gemini example +gemini = client.chat.completions.create( model="gemini-2.5-pro", - messages=[ - {"role": "user", "content": "Hello, how are you?"} - ] + messages=[{"role": "user", "content": "Hello, how are you?"}] ) -print(response.choices[0].message.content) +# Codex/GPT example +gpt = client.chat.completions.create( + model="gpt-5", + messages=[{"role": "user", "content": "Summarize this project in one sentence."}] +) +print(gemini.choices[0].message.content) +print(gpt.choices[0].message.content) ``` #### JavaScript/TypeScript @@ -124,28 +150,35 @@ const openai = new OpenAI({ baseURL: 'http://localhost:8317/v1', }); -const response = await openai.chat.completions.create({ +// Gemini +const gemini = await openai.chat.completions.create({ model: 'gemini-2.5-pro', - messages: [ - { role: 'user', content: 'Hello, how are you?' } - ], + messages: [{ role: 'user', content: 'Hello, how are you?' }], }); -console.log(response.choices[0].message.content); +// Codex/GPT +const gpt = await openai.chat.completions.create({ + model: 'gpt-5', + messages: [{ role: 'user', content: 'Summarize this project in one sentence.' }], +}); + +console.log(gemini.choices[0].message.content); +console.log(gpt.choices[0].message.content); ``` ## Supported Models - gemini-2.5-pro - gemini-2.5-flash -- And it automates switching to various preview versions +- gpt-5 +- Gemini models auto‑switch to preview variants when needed ## Configuration The server uses a YAML configuration file (`config.yaml`) located in the project root directory by default. You can specify a different configuration file path using the `--config` flag: ```bash -./cli-proxy --config /path/to/your/config.yaml +./cli-proxy-api --config /path/to/your/config.yaml ``` ### Configuration Options @@ -211,6 +244,10 @@ Authorization: Bearer your-api-key-1 The `generative-language-api-key` parameter allows you to define a list of API keys that can be used to authenticate requests to the official Generative Language API. +## Hot Reloading + +The server watches the config file and the `auth-dir` for changes and reloads clients and settings automatically. You can add or remove Gemini/OpenAI token JSON files while the server is running; no restart is required. + ## Gemini CLI with multiple account load balancing Start CLI Proxy API server, and then set the `CODE_ASSIST_ENDPOINT` environment variable to the URL of the CLI Proxy API server. 
@@ -227,12 +264,18 @@ The server will relay the `loadCodeAssist`, `onboardUser`, and `countTokens` req ## Run with Docker -Run the following command to login: +Run the following command to login (Gemini OAuth on port 8085): ```bash docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login ``` +Run the following command to login (OpenAI OAuth on port 1455): + +```bash +docker run --rm -p 1455:1455 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --codex-login +``` + Run the following command to start the server: ```bash diff --git a/README_CN.md b/README_CN.md index 018697ea..206286e7 100644 --- a/README_CN.md +++ b/README_CN.md @@ -2,16 +2,21 @@ [English](README.md) | 中文 -一个为 CLI 提供 OpenAI/Gemini/Claude 兼容 API 接口的代理服务器。这让您可以摆脱终端界面的束缚,将 Gemini 的强大能力以 API 的形式轻松接入到任何您喜爱的客户端或应用中。 +一个为 CLI 提供 OpenAI/Gemini/Claude 兼容 API 接口的代理服务器。 + +现已支持通过 OAuth 登录接入 OpenAI Codex(GPT 系列)。 + +可与本地或多账户方式配合,使用任何 OpenAI 兼容的客户端与 SDK。 ## 功能特性 - 为 CLI 模型提供 OpenAI/Gemini/Claude 兼容的 API 端点 -- 支持流式和非流式响应 +- 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录) +- 支持流式与非流式响应 - 函数调用/工具支持 -- 多模态输入支持(文本和图像) -- 多账户支持与负载均衡 -- 简单的 CLI 身份验证流程 +- 多模态输入(文本、图片) +- 多账户支持与轮询负载均衡(Gemini 与 OpenAI) +- 简单的 CLI 身份验证流程(Gemini 与 OpenAI) - 支持 Gemini AIStudio API 密钥 - 支持 Gemini CLI 多账户轮询 @@ -20,7 +25,8 @@ ### 前置要求 - Go 1.24 或更高版本 -- 有权访问 CLI 模型的 Google 账户 +- 有权访问 Gemini CLI 模型的 Google 账户(可选) +- 有权访问 OpenAI Codex/GPT 的 OpenAI 账户(可选) ### 从源码构建 @@ -39,17 +45,23 @@ ### 身份验证 -在使用 API 之前,您需要使用 Google 账户进行身份验证: +您可以分别为 Gemini 和 OpenAI 进行身份验证,二者可同时存在于同一个 `auth-dir` 中并参与负载均衡。 -```bash -./cli-proxy-api --login -``` +- Gemini(Google): + ```bash + ./cli-proxy-api --login + ``` + 如果您是旧版 gemini code 用户,可能需要指定项目 ID: + ```bash + ./cli-proxy-api --login --project_id + ``` + 本地 OAuth 回调端口为 `8085`。 -如果您是旧版 gemini code 用户,可能需要指定项目 ID: - -```bash -./cli-proxy-api --login --project_id -``` +- OpenAI(Codex/GPT,OAuth): + ```bash + ./cli-proxy-api --codex-login + ``` + 选项:加上 `--no-browser` 可打印登录地址而不自动打开浏览器。本地 OAuth 回调端口为 `1455`。 ### 启动服务器 @@ -90,6 +102,15 @@ POST http://localhost:8317/v1/chat/completions } ``` +说明: +- 使用 `gemini-*` 模型(如 `gemini-2.5-pro`)走 Gemini,使用 `gpt-*` 模型(如 `gpt-5`)走 OpenAI,服务会自动路由到对应提供商。 + +#### Claude 消息(SSE 兼容) + +``` +POST http://localhost:8317/v1/messages +``` + ### 与 OpenAI 库一起使用 您可以通过将基础 URL 设置为本地服务器来将此代理与任何 OpenAI 兼容的库一起使用: @@ -104,14 +125,20 @@ client = OpenAI( base_url="http://localhost:8317/v1" ) -response = client.chat.completions.create( +# Gemini 示例 +gemini = client.chat.completions.create( model="gemini-2.5-pro", - messages=[ - {"role": "user", "content": "你好,你好吗?"} - ] + messages=[{"role": "user", "content": "你好,你好吗?"}] ) -print(response.choices[0].message.content) +# Codex/GPT 示例 +gpt = client.chat.completions.create( + model="gpt-5", + messages=[{"role": "user", "content": "用一句话总结这个项目"}] +) + +print(gemini.choices[0].message.content) +print(gpt.choices[0].message.content) ``` #### JavaScript/TypeScript @@ -124,28 +151,35 @@ const openai = new OpenAI({ baseURL: 'http://localhost:8317/v1', }); -const response = await openai.chat.completions.create({ +// Gemini +const gemini = await openai.chat.completions.create({ model: 'gemini-2.5-pro', - messages: [ - { role: 'user', content: '你好,你好吗?' } - ], + messages: [{ role: 'user', content: '你好,你好吗?' 
}], }); -console.log(response.choices[0].message.content); +// Codex/GPT +const gpt = await openai.chat.completions.create({ + model: 'gpt-5', + messages: [{ role: 'user', content: '用一句话总结这个项目' }], +}); + +console.log(gemini.choices[0].message.content); +console.log(gpt.choices[0].message.content); ``` ## 支持的模型 - gemini-2.5-pro - gemini-2.5-flash -- 并且自动切换到之前的预览版本 +- gpt-5 +- Gemini 模型在需要时自动切换到对应的 preview 版本 ## 配置 服务器默认使用位于项目根目录的 YAML 配置文件(`config.yaml`)。您可以使用 `--config` 标志指定不同的配置文件路径: ```bash -./cli-proxy --config /path/to/your/config.yaml +./cli-proxy-api --config /path/to/your/config.yaml ``` ### 配置选项 @@ -211,6 +245,10 @@ Authorization: Bearer your-api-key-1 `generative-language-api-key` 参数允许您定义可用于验证对官方 AIStudio Gemini API 请求的 API 密钥列表。 +## 热更新 + +服务会监听配置文件与 `auth-dir` 目录的变化并自动重新加载客户端与配置。您可以在运行中新增/移除 Gemini/OpenAI 的令牌 JSON 文件,无需重启服务。 + ## Gemini CLI 多账户负载均衡 启动 CLI 代理 API 服务器,然后将 `CODE_ASSIST_ENDPOINT` 环境变量设置为 CLI 代理 API 服务器的 URL。 @@ -227,12 +265,18 @@ export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317" ## 使用 Docker 运行 -运行以下命令进行登录: +运行以下命令进行登录(Gemini OAuth,端口 8085): ```bash docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login ``` +运行以下命令进行登录(OpenAI OAuth,端口 1455): + +```bash +docker run --rm -p 1455:1455 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --codex-login +``` + 运行以下命令启动服务器: ```bash @@ -251,4 +295,4 @@ docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.ya ## 许可证 -此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。 \ No newline at end of file +此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。 diff --git a/cmd/server/main.go b/cmd/server/main.go index cea5e6ee..c3b5c64e 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -7,19 +7,23 @@ import ( "bytes" "flag" "fmt" - "github.com/luispater/CLIProxyAPI/internal/cmd" - "github.com/luispater/CLIProxyAPI/internal/config" - log "github.com/sirupsen/logrus" "os" "path" "strings" + + "github.com/luispater/CLIProxyAPI/internal/cmd" + "github.com/luispater/CLIProxyAPI/internal/config" + log "github.com/sirupsen/logrus" ) // LogFormatter defines a custom log format for logrus. +// This formatter adds timestamp, log level, and source location information +// to each log entry for better debugging and monitoring. type LogFormatter struct { } -// Format renders a single log entry. +// Format renders a single log entry with custom formatting. +// It includes timestamp, log level, source file and line number, and the log message. func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { var b *bytes.Buffer if entry.Buffer != nil { @@ -38,6 +42,8 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { } // init initializes the logger configuration. +// It sets up the custom log formatter, enables caller reporting, +// and configures the log output destination. func init() { // Set logger output to standard output. log.SetOutput(os.Stdout) @@ -48,14 +54,20 @@ func init() { } // main is the entry point of the application. +// It parses command-line flags, loads configuration, and starts the appropriate +// service based on the provided flags (login, codex-login, or server mode). func main() { var login bool + var codexLogin bool + var noBrowser bool var projectID string var configPath string - // Define command-line flags. 
+ // Define command-line flags for different operation modes. flag.BoolVar(&login, "login", false, "Login Google Account") - flag.StringVar(&projectID, "project_id", "", "Project ID") + flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth") + flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth") + flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", "", "Configure File Path") // Parse the command-line flags. @@ -104,10 +116,19 @@ func main() { } } - // Either perform login or start the service based on the 'login' flag. + // Handle different command modes based on the provided flags. + options := &cmd.LoginOptions{ + NoBrowser: noBrowser, + } + if login { - cmd.DoLogin(cfg, projectID) + // Handle Google/Gemini login + cmd.DoLogin(cfg, projectID, options) + } else if codexLogin { + // Handle Codex login + cmd.DoCodexLogin(cfg, options) } else { + // Start the main proxy service cmd.StartService(cfg, configFilePath) } } diff --git a/config.example.yaml b/config.example.yaml index 552276d4..ed25b123 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,15 +1,22 @@ +# Server configuration port: 8317 auth-dir: "~/.cli-proxy-api" debug: true proxy-url: "" + +# Quota exceeded behavior quota-exceeded: switch-project: true switch-preview-model: true + +# API keys for client authentication api-keys: - "12345" - "23456" + +# Generative language API keys generative-language-api-key: - "AIzaSy...01" - "AIzaSy...02" - "AIzaSy...03" - - "AIzaSy...04" + - "AIzaSy...04" \ No newline at end of file diff --git a/go.mod b/go.mod index 6c44518c..842cb74f 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module github.com/luispater/CLIProxyAPI go 1.24 require ( + github.com/fsnotify/fsnotify v1.9.0 github.com/gin-gonic/gin v1.10.1 + github.com/google/uuid v1.6.0 github.com/sirupsen/logrus v1.9.3 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/tidwall/gjson v1.18.0 @@ -19,7 +21,6 @@ require ( github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect - github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect diff --git a/go.sum b/go.sum index 4703cffb..68408afc 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= diff --git a/internal/api/handlers/claude/code-handlers.go b/internal/api/handlers/claude/code-handlers.go deleted file mode 100644 index 4427cc43..00000000 --- a/internal/api/handlers/claude/code-handlers.go +++ /dev/null @@ -1,208 +0,0 @@ -// Package 
claude provides HTTP handlers for Claude API code-related functionality. -// This package implements Claude-compatible streaming chat completions with sophisticated -// client rotation and quota management systems to ensure high availability and optimal -// resource utilization across multiple backend clients. It handles request translation -// between Claude API format and the underlying Gemini backend, providing seamless -// API compatibility while maintaining robust error handling and connection management. -package claude - -import ( - "context" - "fmt" - "github.com/gin-gonic/gin" - "github.com/luispater/CLIProxyAPI/internal/api/handlers" - "github.com/luispater/CLIProxyAPI/internal/api/translator/claude/code" - "github.com/luispater/CLIProxyAPI/internal/client" - log "github.com/sirupsen/logrus" - "net/http" - "strings" - "time" -) - -// ClaudeCodeAPIHandlers contains the handlers for Claude API endpoints. -// It holds a pool of clients to interact with the backend service. -type ClaudeCodeAPIHandlers struct { - *handlers.APIHandlers -} - -// NewClaudeCodeAPIHandlers creates a new Claude API handlers instance. -// It takes an APIHandlers instance as input and returns a ClaudeCodeAPIHandlers. -func NewClaudeCodeAPIHandlers(apiHandlers *handlers.APIHandlers) *ClaudeCodeAPIHandlers { - return &ClaudeCodeAPIHandlers{ - APIHandlers: apiHandlers, - } -} - -// ClaudeMessages handles Claude-compatible streaming chat completions. -// This function implements a sophisticated client rotation and quota management system -// to ensure high availability and optimal resource utilization across multiple backend clients. -func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) { - // Extract raw JSON data from the incoming request - rawJSON, err := c.GetRawData() - // If data retrieval fails, return a 400 Bad Request error. - if err != nil { - c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: fmt.Sprintf("Invalid request: %v", err), - Type: "invalid_request_error", - }, - }) - return - } - - // Set up Server-Sent Events (SSE) headers for streaming response - // These headers are essential for maintaining a persistent connection - // and enabling real-time streaming of chat completions - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - - // Get the http.Flusher interface to manually flush the response. 
- // This is crucial for streaming as it allows immediate sending of data chunks - flusher, ok := c.Writer.(http.Flusher) - if !ok { - c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: "Streaming not supported", - Type: "server_error", - }, - }) - return - } - - // Parse and prepare the Claude request, extracting model name, system instructions, - // conversation contents, and available tools from the raw JSON - modelName, systemInstruction, contents, tools := code.PrepareClaudeRequest(rawJSON) - - // Map Claude model names to corresponding Gemini models - // This allows the proxy to handle Claude API calls using Gemini backend - if modelName == "claude-sonnet-4-20250514" { - modelName = "gemini-2.5-pro" - } else if modelName == "claude-3-5-haiku-20241022" { - modelName = "gemini-2.5-flash" - } - - // Create a cancellable context for the backend client request - // This allows proper cleanup and cancellation of ongoing requests - cliCtx, cliCancel := context.WithCancel(context.Background()) - var cliClient *client.Client - defer func() { - // Ensure the client's mutex is unlocked on function exit. - // This prevents deadlocks and ensures proper resource cleanup - if cliClient != nil { - cliClient.RequestMutex.Unlock() - } - }() - - // Main client rotation loop with quota management - // This loop implements a sophisticated load balancing and failover mechanism -outLoop: - for { - var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.GetClient(modelName) - if errorResponse != nil { - c.Status(errorResponse.StatusCode) - _, _ = fmt.Fprint(c.Writer, errorResponse.Error) - flusher.Flush() - cliCancel() - return - } - - // Determine the authentication method being used by the selected client - // This affects how responses are formatted and logged - isGlAPIKey := false - if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { - log.Debugf("Request use generative language API Key: %s", glAPIKey) - isGlAPIKey = true - } else { - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - } - // Initiate streaming communication with the backend client - // This returns two channels: one for response chunks and one for errors - - includeThoughts := false - if userAgent, hasKey := c.Request.Header["User-Agent"]; hasKey { - includeThoughts = !strings.Contains(userAgent[0], "claude-cli") - } - - respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools, includeThoughts) - - // Track response state for proper Claude format conversion - hasFirstResponse := false - responseType := 0 - responseIndex := 0 - - // Main streaming loop - handles multiple concurrent events using Go channels - // This select statement manages four different types of events simultaneously - for { - select { - // Case 1: Handle client disconnection - // Detects when the HTTP client has disconnected and cleans up resources - case <-c.Request.Context().Done(): - if c.Request.Context().Err().Error() == "context canceled" { - log.Debugf("Client disconnected: %v", c.Request.Context().Err()) - cliCancel() // Cancel the backend request to prevent resource leaks - return - } - - // Case 2: Process incoming response chunks from the backend - // This handles the actual streaming data from the AI model - case chunk, okStream := <-respChan: - if !okStream { - // Stream has ended - send the final message_stop event - // This follows the Claude API specification 
for stream termination - _, _ = c.Writer.Write([]byte(`event: message_stop`)) - _, _ = c.Writer.Write([]byte("\n")) - _, _ = c.Writer.Write([]byte(`data: {"type":"message_stop"}`)) - _, _ = c.Writer.Write([]byte("\n\n\n")) - - flusher.Flush() - cliCancel() - return - } - // Convert the backend response to Claude-compatible format - // This translation layer ensures API compatibility - claudeFormat := code.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex) - if claudeFormat != "" { - _, _ = c.Writer.Write([]byte(claudeFormat)) - flusher.Flush() // Immediately send the chunk to the client - } - hasFirstResponse = true - - // Case 3: Handle errors from the backend - // This manages various error conditions and implements retry logic - case errInfo, okError := <-errChan: - if okError { - // Special handling for quota exceeded errors - // If configured, attempt to switch to a different project/client - if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { - continue outLoop // Restart the client selection process - } else { - // Forward other errors directly to the client - c.Status(errInfo.StatusCode) - _, _ = fmt.Fprint(c.Writer, errInfo.Error.Error()) - flusher.Flush() - cliCancel() - } - return - } - - // Case 4: Send periodic keep-alive signals - // Prevents connection timeouts during long-running requests - case <-time.After(500 * time.Millisecond): - if hasFirstResponse { - // Send a ping event to maintain the connection - // This is especially important for slow AI model responses - output := "event: ping\n" - output = output + `data: {"type": "ping"}` - output = output + "\n\n\n" - _, _ = c.Writer.Write([]byte(output)) - - flusher.Flush() - } - } - } - } - -} diff --git a/internal/api/handlers/claude/code_handlers.go b/internal/api/handlers/claude/code_handlers.go new file mode 100644 index 00000000..7194e0e4 --- /dev/null +++ b/internal/api/handlers/claude/code_handlers.go @@ -0,0 +1,382 @@ +// Package claude provides HTTP handlers for Claude API code-related functionality. +// This package implements Claude-compatible streaming chat completions with sophisticated +// client rotation and quota management systems to ensure high availability and optimal +// resource utilization across multiple backend clients. It handles request translation +// between Claude API format and the underlying Gemini backend, providing seamless +// API compatibility while maintaining robust error handling and connection management. +package claude + +import ( + "bytes" + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/luispater/CLIProxyAPI/internal/api/handlers" + "github.com/luispater/CLIProxyAPI/internal/client" + translatorClaudeCodeToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/claude/code" + translatorClaudeCodeToGeminiCli "github.com/luispater/CLIProxyAPI/internal/translator/gemini-cli/claude/code" + "github.com/luispater/CLIProxyAPI/internal/util" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ClaudeCodeAPIHandlers contains the handlers for Claude API endpoints. +// It holds a pool of clients to interact with the backend service. +type ClaudeCodeAPIHandlers struct { + *handlers.APIHandlers +} + +// NewClaudeCodeAPIHandlers creates a new Claude API handlers instance. +// It takes an APIHandlers instance as input and returns a ClaudeCodeAPIHandlers. 
+func NewClaudeCodeAPIHandlers(apiHandlers *handlers.APIHandlers) *ClaudeCodeAPIHandlers {
+	return &ClaudeCodeAPIHandlers{
+		APIHandlers: apiHandlers,
+	}
+}
+
+// ClaudeMessages handles Claude-compatible streaming chat completions.
+// This function implements a sophisticated client rotation and quota management system
+// to ensure high availability and optimal resource utilization across multiple backend clients.
+func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) {
+	// Extract raw JSON data from the incoming request
+	rawJSON, err := c.GetRawData()
+	// If data retrieval fails, return a 400 Bad Request error.
+	if err != nil {
+		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
+				Message: fmt.Sprintf("Invalid request: %v", err),
+				Type:    "invalid_request_error",
+			},
+		})
+		return
+	}
+
+	// Route the request to the Gemini or Codex backend based on the model name;
+	// unknown providers fall back to Gemini.
+	modelName := gjson.GetBytes(rawJSON, "model")
+	provider := util.GetProviderName(modelName.String())
+	if provider == "gemini" {
+		h.handleGeminiStreamingResponse(c, rawJSON)
+	} else if provider == "gpt" {
+		h.handleCodexStreamingResponse(c, rawJSON)
+	} else {
+		h.handleGeminiStreamingResponse(c, rawJSON)
+	}
+}
+
+// handleGeminiStreamingResponse streams Claude-compatible responses backed by Gemini.
+// It sets up SSE, selects a backend client with rotation/quota logic,
+// forwards chunks, and translates them to Claude CLI format.
+func (h *ClaudeCodeAPIHandlers) handleGeminiStreamingResponse(c *gin.Context, rawJSON []byte) {
+	// Set up Server-Sent Events (SSE) headers for streaming response
+	// These headers are essential for maintaining a persistent connection
+	// and enabling real-time streaming of chat completions
+	c.Header("Content-Type", "text/event-stream")
+	c.Header("Cache-Control", "no-cache")
+	c.Header("Connection", "keep-alive")
+	c.Header("Access-Control-Allow-Origin", "*")
+
+	// Get the http.Flusher interface to manually flush the response.
+	// This is crucial for streaming as it allows immediate sending of data chunks
+	flusher, ok := c.Writer.(http.Flusher)
+	if !ok {
+		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
+				Message: "Streaming not supported",
+				Type:    "server_error",
+			},
+		})
+		return
+	}
+
+	// Parse and prepare the Claude request, extracting model name, system instructions,
+	// conversation contents, and available tools from the raw JSON
+	modelName, systemInstruction, contents, tools := translatorClaudeCodeToGeminiCli.ConvertClaudeCodeRequestToCli(rawJSON)
+
+	// Map Claude model names to corresponding Gemini models
+	// This allows the proxy to handle Claude API calls using Gemini backend
+	if modelName == "claude-sonnet-4-20250514" {
+		modelName = "gemini-2.5-pro"
+	} else if modelName == "claude-3-5-haiku-20241022" {
+		modelName = "gemini-2.5-flash"
+	}
+
+	// Create a cancellable context for the backend client request
+	// This allows proper cleanup and cancellation of ongoing requests
+	cliCtx, cliCancel := context.WithCancel(context.Background())
+	var cliClient client.Client
+	defer func() {
+		// Ensure the client's mutex is unlocked on function exit.
+ // This prevents deadlocks and ensures proper resource cleanup + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + + // Main client rotation loop with quota management + // This loop implements a sophisticated load balancing and failover mechanism +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + flusher.Flush() + cliCancel() + return + } + + // Determine the authentication method being used by the selected client + // This affects how responses are formatted and logged + isGlAPIKey := false + if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use gemini generative language API Key: %s", glAPIKey) + isGlAPIKey = true + } else { + log.Debugf("Request use gemini account: %s, project id: %s", cliClient.GetEmail(), cliClient.(*client.GeminiClient).GetProjectID()) + } + // Initiate streaming communication with the backend client + // This returns two channels: one for response chunks and one for errors + + includeThoughts := false + if userAgent, hasKey := c.Request.Header["User-Agent"]; hasKey { + includeThoughts = !strings.Contains(userAgent[0], "claude-cli") + } + + respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools, includeThoughts) + + // Track response state for proper Claude format conversion + hasFirstResponse := false + responseType := 0 + responseIndex := 0 + + // Main streaming loop - handles multiple concurrent events using Go channels + // This select statement manages four different types of events simultaneously + for { + select { + // Case 1: Handle client disconnection + // Detects when the HTTP client has disconnected and cleans up resources + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request to prevent resource leaks + return + } + + // Case 2: Process incoming response chunks from the backend + // This handles the actual streaming data from the AI model + case chunk, okStream := <-respChan: + if !okStream { + // Stream has ended - send the final message_stop event + // This follows the Claude API specification for stream termination + _, _ = c.Writer.Write([]byte(`event: message_stop`)) + _, _ = c.Writer.Write([]byte("\n")) + _, _ = c.Writer.Write([]byte(`data: {"type":"message_stop"}`)) + _, _ = c.Writer.Write([]byte("\n\n\n")) + + flusher.Flush() + cliCancel() + return + } + // Convert the backend response to Claude-compatible format + // This translation layer ensures API compatibility + claudeFormat := translatorClaudeCodeToGeminiCli.ConvertCliResponseToClaudeCode(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex) + if claudeFormat != "" { + _, _ = c.Writer.Write([]byte(claudeFormat)) + flusher.Flush() // Immediately send the chunk to the client + } + hasFirstResponse = true + + // Case 3: Handle errors from the backend + // This manages various error conditions and implements retry logic + case errInfo, okError := <-errChan: + if okError { + // Special handling for quota exceeded errors + // If configured, attempt to switch to a different project/client + if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue outLoop // Restart the client 
selection process
+				} else {
+					// Forward other errors directly to the client
+					c.Status(errInfo.StatusCode)
+					_, _ = fmt.Fprint(c.Writer, errInfo.Error.Error())
+					flusher.Flush()
+					cliCancel()
+				}
+				return
+			}
+
+		// Case 4: Send periodic keep-alive signals
+		// Prevents connection timeouts during long-running requests
+		case <-time.After(500 * time.Millisecond):
+			if hasFirstResponse {
+				// Send a ping event to maintain the connection
+				// This is especially important for slow AI model responses
+				output := "event: ping\n"
+				output = output + `data: {"type": "ping"}`
+				output = output + "\n\n\n"
+				_, _ = c.Writer.Write([]byte(output))
+
+				flusher.Flush()
+			}
+		}
+	}
+}
+
+// handleCodexStreamingResponse streams Claude-compatible responses backed by OpenAI.
+// It converts the Claude request into the Codex/OpenAI Responses format, establishes SSE,
+// and translates streaming chunks back into Claude CLI events.
+func (h *ClaudeCodeAPIHandlers) handleCodexStreamingResponse(c *gin.Context, rawJSON []byte) {
+	// Set up Server-Sent Events (SSE) headers for streaming response
+	// These headers are essential for maintaining a persistent connection
+	// and enabling real-time streaming of chat completions
+	c.Header("Content-Type", "text/event-stream")
+	c.Header("Cache-Control", "no-cache")
+	c.Header("Connection", "keep-alive")
+	c.Header("Access-Control-Allow-Origin", "*")
+
+	// Get the http.Flusher interface to manually flush the response.
+	// This is crucial for streaming as it allows immediate sending of data chunks
+	flusher, ok := c.Writer.(http.Flusher)
+	if !ok {
+		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
+				Message: "Streaming not supported",
+				Type:    "server_error",
+			},
+		})
+		return
+	}
+
+	// Convert the Claude request into the Codex/OpenAI Responses request format.
+	newRequestJSON := translatorClaudeCodeToCodex.ConvertClaudeCodeRequestToCodex(rawJSON)
+	modelName := gjson.GetBytes(rawJSON, "model").String()
+	// Map Claude model names to corresponding GPT models
+	// This allows the proxy to handle Claude API calls using the Codex backend
+	if modelName == "claude-sonnet-4-20250514" {
+		modelName = "gpt-5"
+	} else if modelName == "claude-3-5-haiku-20241022" {
+		modelName = "gpt-5"
+	}
+	newRequestJSON, _ = sjson.Set(newRequestJSON, "model", modelName)
+
+	// Create a cancellable context for the backend client request
+	// This allows proper cleanup and cancellation of ongoing requests
+	cliCtx, cliCancel := context.WithCancel(context.Background())
+	var cliClient client.Client
+	defer func() {
+		// Ensure the client's mutex is unlocked on function exit.
+ // This prevents deadlocks and ensures proper resource cleanup + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + + // Main client rotation loop with quota management + // This loop implements a sophisticated load balancing and failover mechanism +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + flusher.Flush() + cliCancel() + return + } + + log.Debugf("Request use codex account: %s", cliClient.GetEmail()) + + // Initiate streaming communication with the backend client + // This returns two channels: one for response chunks and one for errors + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "") + + // Track response state for proper Claude format conversion + hasFirstResponse := false + hasToolCall := false + + // Main streaming loop - handles multiple concurrent events using Go channels + // This select statement manages four different types of events simultaneously + for { + select { + // Case 1: Handle client disconnection + // Detects when the HTTP client has disconnected and cleans up resources + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request to prevent resource leaks + return + } + + // Case 2: Process incoming response chunks from the backend + // This handles the actual streaming data from the AI model + case chunk, okStream := <-respChan: + if !okStream { + flusher.Flush() + cliCancel() + return + } + // Convert the backend response to Claude-compatible format + // This translation layer ensures API compatibility + if bytes.HasPrefix(chunk, []byte("data: ")) { + jsonData := chunk[6:] + var claudeFormat string + claudeFormat, hasToolCall = translatorClaudeCodeToCodex.ConvertCodexResponseToClaude(jsonData, hasToolCall) + // log.Debugf("claudeFormat: %s", claudeFormat) + if claudeFormat != "" { + _, _ = c.Writer.Write([]byte(claudeFormat)) + _, _ = c.Writer.Write([]byte("\n")) + } + flusher.Flush() // Immediately send the chunk to the client + hasFirstResponse = true + } else { + // log.Debugf("chunk: %s", string(chunk)) + } + // Case 3: Handle errors from the backend + // This manages various error conditions and implements retry logic + case errInfo, okError := <-errChan: + if okError { + // log.Debugf("Code: %d, Error: %v", errInfo.StatusCode, errInfo.Error) + // Special handling for quota exceeded errors + // If configured, attempt to switch to a different project/client + if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + log.Debugf("quota exceeded, switch client") + continue outLoop // Restart the client selection process + } else { + // Forward other errors directly to the client + c.Status(errInfo.StatusCode) + _, _ = fmt.Fprint(c.Writer, errInfo.Error.Error()) + flusher.Flush() + cliCancel() + } + return + } + + // Case 4: Send periodic keep-alive signals + // Prevents connection timeouts during long-running requests + case <-time.After(3000 * time.Millisecond): + if hasFirstResponse { + // Send a ping event to maintain the connection + // This is especially important for slow AI model responses + output := "event: ping\n" + output = output + `data: {"type": "ping"}` + output = output + "\n\n" + _, _ = c.Writer.Write([]byte(output)) + + flusher.Flush() + 
} + } + } + } +} diff --git a/internal/api/handlers/gemini/cli/cli-handlers.go b/internal/api/handlers/gemini/cli/cli-handlers.go deleted file mode 100644 index 7c1dde77..00000000 --- a/internal/api/handlers/gemini/cli/cli-handlers.go +++ /dev/null @@ -1,268 +0,0 @@ -// Package cli provides HTTP handlers for Gemini CLI API functionality. -// This package implements handlers that process CLI-specific requests for Gemini API operations, -// including content generation and streaming content generation endpoints. -// The handlers restrict access to localhost only and manage communication with the backend service. -package cli - -import ( - "bytes" - "context" - "fmt" - "github.com/gin-gonic/gin" - "github.com/luispater/CLIProxyAPI/internal/api/handlers" - "github.com/luispater/CLIProxyAPI/internal/client" - "github.com/luispater/CLIProxyAPI/internal/util" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "io" - "net/http" - "strings" - "time" -) - -// GeminiCLIAPIHandlers contains the handlers for Gemini CLI API endpoints. -// It holds a pool of clients to interact with the backend service. -type GeminiCLIAPIHandlers struct { - *handlers.APIHandlers -} - -// NewGeminiCLIAPIHandlers creates a new Gemini CLI API handlers instance. -// It takes an APIHandlers instance as input and returns a GeminiCLIAPIHandlers. -func NewGeminiCLIAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiCLIAPIHandlers { - return &GeminiCLIAPIHandlers{ - APIHandlers: apiHandlers, - } -} - -// CLIHandler handles CLI-specific requests for Gemini API operations. -// It restricts access to localhost only and routes requests to appropriate internal handlers. -func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) { - if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") { - c.JSON(http.StatusForbidden, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: "CLI reply only allow local access", - Type: "forbidden", - }, - }) - return - } - - rawJSON, _ := c.GetRawData() - requestRawURI := c.Request.URL.Path - if requestRawURI == "/v1internal:generateContent" { - h.internalGenerateContent(c, rawJSON) - } else if requestRawURI == "/v1internal:streamGenerateContent" { - h.internalStreamGenerateContent(c, rawJSON) - } else { - reqBody := bytes.NewBuffer(rawJSON) - req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody) - if err != nil { - c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: fmt.Sprintf("Invalid request: %v", err), - Type: "invalid_request_error", - }, - }) - return - } - for key, value := range c.Request.Header { - req.Header[key] = value - } - - httpClient, err := util.SetProxy(h.Cfg, &http.Client{}) - if err != nil { - log.Fatalf("set proxy failed: %v", err) - } - - resp, err := httpClient.Do(req) - if err != nil { - c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: fmt.Sprintf("Invalid request: %v", err), - Type: "invalid_request_error", - }, - }) - return - } - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - defer func() { - if err = resp.Body.Close(); err != nil { - log.Printf("warn: failed to close response body: %v", err) - } - }() - bodyBytes, _ := io.ReadAll(resp.Body) - - c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: string(bodyBytes), - Type: "invalid_request_error", - }, - }) - return - } - - defer func() { - _ = 
resp.Body.Close() - }() - - for key, value := range resp.Header { - c.Header(key, value[0]) - } - output, err := io.ReadAll(resp.Body) - if err != nil { - log.Errorf("Failed to read response body: %v", err) - return - } - _, _ = c.Writer.Write(output) - } -} - -func (h *GeminiCLIAPIHandlers) internalStreamGenerateContent(c *gin.Context, rawJSON []byte) { - alt := h.GetAlt(c) - - if alt == "" { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - } - - // Get the http.Flusher interface to manually flush the response. - flusher, ok := c.Writer.(http.Flusher) - if !ok { - c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: "Streaming not supported", - Type: "server_error", - }, - }) - return - } - - modelResult := gjson.GetBytes(rawJSON, "model") - modelName := modelResult.String() - - cliCtx, cliCancel := context.WithCancel(context.Background()) - var cliClient *client.Client - defer func() { - // Ensure the client's mutex is unlocked on function exit. - if cliClient != nil { - cliClient.RequestMutex.Unlock() - } - }() - -outLoop: - for { - var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.GetClient(modelName) - if errorResponse != nil { - c.Status(errorResponse.StatusCode) - _, _ = fmt.Fprint(c.Writer, errorResponse.Error) - flusher.Flush() - cliCancel() - return - } - - if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { - log.Debugf("Request use generative language API Key: %s", glAPIKey) - } else { - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - } - // Send the message and receive response chunks and errors via channels. - respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, "") - hasFirstResponse := false - for { - select { - // Handle client disconnection. - case <-c.Request.Context().Done(): - if c.Request.Context().Err().Error() == "context canceled" { - log.Debugf("Client disconnected: %v", c.Request.Context().Err()) - cliCancel() // Cancel the backend request. - return - } - // Process incoming response chunks. - case chunk, okStream := <-respChan: - if !okStream { - cliCancel() - return - } - hasFirstResponse = true - if cliClient.GetGenerativeLanguageAPIKey() != "" { - chunk, _ = sjson.SetRawBytes(chunk, "response", chunk) - } - _, _ = c.Writer.Write([]byte("data: ")) - _, _ = c.Writer.Write(chunk) - _, _ = c.Writer.Write([]byte("\n\n")) - flusher.Flush() - // Handle errors from the backend. - case err, okError := <-errChan: - if okError { - if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { - continue outLoop - } else { - c.Status(err.StatusCode) - _, _ = fmt.Fprint(c.Writer, err.Error.Error()) - flusher.Flush() - cliCancel() - } - return - } - // Send a keep-alive signal to the client. 
- case <-time.After(500 * time.Millisecond): - if hasFirstResponse { - _, _ = c.Writer.Write([]byte("\n")) - flusher.Flush() - } - } - } - } -} - -func (h *GeminiCLIAPIHandlers) internalGenerateContent(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "application/json") - - modelResult := gjson.GetBytes(rawJSON, "model") - modelName := modelResult.String() - cliCtx, cliCancel := context.WithCancel(context.Background()) - var cliClient *client.Client - defer func() { - if cliClient != nil { - cliClient.RequestMutex.Unlock() - } - }() - - for { - var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.GetClient(modelName) - if errorResponse != nil { - c.Status(errorResponse.StatusCode) - _, _ = fmt.Fprint(c.Writer, errorResponse.Error) - cliCancel() - return - } - - if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { - log.Debugf("Request use generative language API Key: %s", glAPIKey) - } else { - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - } - - resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, "") - if err != nil { - if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { - continue - } else { - c.Status(err.StatusCode) - _, _ = c.Writer.Write([]byte(err.Error.Error())) - cliCancel() - } - break - } else { - _, _ = c.Writer.Write(resp) - cliCancel() - break - } - } -} diff --git a/internal/api/handlers/gemini/cli/cli_handlers.go b/internal/api/handlers/gemini/cli/cli_handlers.go new file mode 100644 index 00000000..5100b968 --- /dev/null +++ b/internal/api/handlers/gemini/cli/cli_handlers.go @@ -0,0 +1,491 @@ +// Package cli provides HTTP handlers for Gemini CLI API functionality. +// This package implements handlers that process CLI-specific requests for Gemini API operations, +// including content generation and streaming content generation endpoints. +// The handlers restrict access to localhost only and manage communication with the backend service. +package cli + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/luispater/CLIProxyAPI/internal/api/handlers" + "github.com/luispater/CLIProxyAPI/internal/client" + translatorGeminiToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/gemini" + "github.com/luispater/CLIProxyAPI/internal/util" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// GeminiCLIAPIHandlers contains the handlers for Gemini CLI API endpoints. +// It holds a pool of clients to interact with the backend service. +type GeminiCLIAPIHandlers struct { + *handlers.APIHandlers +} + +// NewGeminiCLIAPIHandlers creates a new Gemini CLI API handlers instance. +// It takes an APIHandlers instance as input and returns a GeminiCLIAPIHandlers. +func NewGeminiCLIAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiCLIAPIHandlers { + return &GeminiCLIAPIHandlers{ + APIHandlers: apiHandlers, + } +} + +// CLIHandler handles CLI-specific requests for Gemini API operations. +// It restricts access to localhost only and routes requests to appropriate internal handlers. 
+func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) { + if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") { + c.JSON(http.StatusForbidden, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "CLI reply only allow local access", + Type: "forbidden", + }, + }) + return + } + + rawJSON, _ := c.GetRawData() + requestRawURI := c.Request.URL.Path + + modelName := gjson.GetBytes(rawJSON, "model") + provider := util.GetProviderName(modelName.String()) + + if requestRawURI == "/v1internal:generateContent" { + if provider == "gemini" || provider == "unknow" { + h.handleInternalGenerateContent(c, rawJSON) + } else if provider == "gpt" { + h.handleCodexInternalGenerateContent(c, rawJSON) + } + } else if requestRawURI == "/v1internal:streamGenerateContent" { + if provider == "gemini" || provider == "unknow" { + h.handleInternalStreamGenerateContent(c, rawJSON) + } else if provider == "gpt" { + h.handleCodexInternalStreamGenerateContent(c, rawJSON) + } + } else { + reqBody := bytes.NewBuffer(rawJSON) + req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody) + if err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + for key, value := range c.Request.Header { + req.Header[key] = value + } + + httpClient := util.SetProxy(h.Cfg, &http.Client{}) + + resp, err := httpClient.Do(req) + if err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer func() { + if err = resp.Body.Close(); err != nil { + log.Printf("warn: failed to close response body: %v", err) + } + }() + bodyBytes, _ := io.ReadAll(resp.Body) + + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: string(bodyBytes), + Type: "invalid_request_error", + }, + }) + return + } + + defer func() { + _ = resp.Body.Close() + }() + + for key, value := range resp.Header { + c.Header(key, value[0]) + } + output, err := io.ReadAll(resp.Body) + if err != nil { + log.Errorf("Failed to read response body: %v", err) + return + } + _, _ = c.Writer.Write(output) + } +} + +func (h *GeminiCLIAPIHandlers) handleInternalStreamGenerateContent(c *gin.Context, rawJSON []byte) { + alt := h.GetAlt(c) + + if alt == "" { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Get the http.Flusher interface to manually flush the response. + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + modelResult := gjson.GetBytes(rawJSON, "model") + modelName := modelResult.String() + + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. 
+ if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + flusher.Flush() + cliCancel() + return + } + + if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID()) + } + // Send the message and receive response chunks and errors via channels. + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, "") + hasFirstResponse := false + for { + select { + // Handle client disconnection. + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. + case chunk, okStream := <-respChan: + if !okStream { + cliCancel() + return + } + hasFirstResponse = true + if cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey() != "" { + chunk, _ = sjson.SetRawBytes(chunk, "response", chunk) + } + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + + flusher.Flush() + // Handle errors from the backend. + case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue outLoop + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel() + } + return + } + // Send a keep-alive signal to the client. 
+ case <-time.After(500 * time.Millisecond): + if hasFirstResponse { + _, _ = c.Writer.Write([]byte("\n")) + flusher.Flush() + } + } + } + } +} + +func (h *GeminiCLIAPIHandlers) handleInternalGenerateContent(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "application/json") + // log.Debugf("GenerateContent: %s", string(rawJSON)) + modelResult := gjson.GetBytes(rawJSON, "model") + modelName := modelResult.String() + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient client.Client + defer func() { + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + cliCancel() + return + } + + if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID()) + } + + resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, "") + if err != nil { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue + } else { + c.Status(err.StatusCode) + _, _ = c.Writer.Write([]byte(err.Error.Error())) + log.Debugf("code: %d, error: %s", err.StatusCode, err.Error.Error()) + cliCancel() + } + break + } else { + _, _ = c.Writer.Write(resp) + cliCancel() + break + } + } +} + +func (h *GeminiCLIAPIHandlers) handleCodexInternalStreamGenerateContent(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + + // Get the http.Flusher interface to manually flush the response. + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + modelResult := gjson.GetBytes(rawJSON, "model") + rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw) + rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String()) + rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw)) + rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction") + + // log.Debugf("Request: %s", string(rawJSON)) + // return + + // Prepare the request for the backend client. + newRequestJSON := translatorGeminiToCodex.ConvertGeminiRequestToCodex(rawJSON) + // log.Debugf("Request: %s", newRequestJSON) + + modelName := gjson.GetBytes(rawJSON, "model") + + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName.String()) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + flusher.Flush() + cliCancel() + return + } + + log.Debugf("Request codex use account: %s", cliClient.GetEmail()) + + // Send the message and receive response chunks and errors via channels. 
+		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
+
+		params := &translatorGeminiToCodex.ConvertCodexResponseToGeminiParams{
+			Model:             modelName.String(),
+			CreatedAt:         0,
+			ResponseID:        "",
+			LastStorageOutput: "",
+		}
+		for {
+			select {
+			// Handle client disconnection.
+			case <-c.Request.Context().Done():
+				if c.Request.Context().Err().Error() == "context canceled" {
+					log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
+					cliCancel() // Cancel the backend request.
+					return
+				}
+			// Process incoming response chunks.
+			case chunk, okStream := <-respChan:
+				if !okStream {
+					cliCancel()
+					return
+				}
+				if bytes.HasPrefix(chunk, []byte("data: ")) {
+					jsonData := chunk[6:]
+					data := gjson.ParseBytes(jsonData)
+					typeResult := data.Get("type")
+					if typeResult.String() != "" {
+						outputs := translatorGeminiToCodex.ConvertCodexResponseToGemini(jsonData, params)
+						if len(outputs) > 0 {
+							for i := 0; i < len(outputs); i++ {
+								outputs[i], _ = sjson.SetRaw("{}", "response", outputs[i])
+								_, _ = c.Writer.Write([]byte("data: "))
+								_, _ = c.Writer.Write([]byte(outputs[i]))
+								_, _ = c.Writer.Write([]byte("\n\n"))
+							}
+						}
+					}
+				}
+				flusher.Flush()
+			// Handle errors from the backend.
+			case errMessage, okError := <-errChan:
+				if okError {
+					if errMessage.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
+						continue outLoop
+					} else {
+						log.Debugf("code: %d, error: %s", errMessage.StatusCode, errMessage.Error.Error())
+						c.Status(errMessage.StatusCode)
+						_, _ = fmt.Fprint(c.Writer, errMessage.Error.Error())
+						flusher.Flush()
+						cliCancel()
+					}
+					return
+				}
+			// Periodic wake-up so the select loop never blocks indefinitely;
+			// this endpoint does not write keep-alive bytes.
+			case <-time.After(500 * time.Millisecond):
+			}
+		}
+	}
+}
+
+func (h *GeminiCLIAPIHandlers) handleCodexInternalGenerateContent(c *gin.Context, rawJSON []byte) {
+	c.Header("Content-Type", "application/json")
+	orgRawJSON := rawJSON
+	modelResult := gjson.GetBytes(rawJSON, "model")
+	rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
+	rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String())
+	rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw))
+	rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction")
+
+	// Prepare the request for the backend client.
+	newRequestJSON := translatorGeminiToCodex.ConvertGeminiRequestToCodex(rawJSON)
+	// log.Debugf("Request: %s", newRequestJSON)
+
+	modelName := gjson.GetBytes(rawJSON, "model")
+
+	cliCtx, cliCancel := context.WithCancel(context.Background())
+	var cliClient client.Client
+	defer func() {
+		// Ensure the client's mutex is unlocked on function exit.
+		if cliClient != nil {
+			cliClient.GetRequestMutex().Unlock()
+		}
+	}()
+
+outLoop:
+	for {
+		var errorResponse *client.ErrorMessage
+		cliClient, errorResponse = h.GetClient(modelName.String())
+		if errorResponse != nil {
+			c.Status(errorResponse.StatusCode)
+			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
+			cliCancel()
+			return
+		}
+
+		log.Debugf("Request codex use account: %s", cliClient.GetEmail())
+
+		// Send the message and receive response chunks and errors via channels.
+		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
+		for {
+			select {
+			// Handle client disconnection.
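
Codex responses arrive as Server-Sent Events, which is why both Codex paths strip a literal `data: ` prefix before handing the JSON to the translator and dispatch only on frames that carry a `type` field. A compact sketch of that framing step; `handleSSEChunk` and the sample event payload are illustrative, and real event type names come from the upstream API:

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/tidwall/gjson"
)

// handleSSEChunk strips the 6-byte "data: " prefix from an SSE frame, parses
// the JSON payload, and forwards only events that declare a "type" field.
func handleSSEChunk(chunk []byte, emit func(event string, payload []byte)) {
	if !bytes.HasPrefix(chunk, []byte("data: ")) {
		return // ignore comments, blank lines, and other non-data frames
	}
	payload := chunk[6:]
	if event := gjson.GetBytes(payload, "type").String(); event != "" {
		emit(event, payload)
	}
}

func main() {
	// A made-up event for illustration only.
	frame := []byte(`data: {"type":"message.delta","delta":"hi"}`)
	handleSSEChunk(frame, func(event string, payload []byte) {
		fmt.Printf("%s -> %s\n", event, payload)
	})
}
```
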
+ case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. + case chunk, okStream := <-respChan: + if !okStream { + cliCancel() + return + } + + if bytes.HasPrefix(chunk, []byte("data: ")) { + jsonData := chunk[6:] + data := gjson.ParseBytes(jsonData) + typeResult := data.Get("type") + if typeResult.String() != "" { + var geminiStr string + geminiStr = translatorGeminiToCodex.ConvertCodexResponseToGeminiNonStream(jsonData, modelName.String()) + if geminiStr != "" { + _, _ = c.Writer.Write([]byte(geminiStr)) + } + } + } + // Handle errors from the backend. + case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue outLoop + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + log.Debugf("org: %s", string(orgRawJSON)) + log.Debugf("raw: %s", string(rawJSON)) + log.Debugf("newRequestJSON: %s", newRequestJSON) + cliCancel() + } + return + } + // Send a keep-alive signal to the client. + case <-time.After(500 * time.Millisecond): + } + } + } +} diff --git a/internal/api/handlers/gemini/gemini-handlers.go b/internal/api/handlers/gemini/gemini-handlers.go deleted file mode 100644 index 160b5daf..00000000 --- a/internal/api/handlers/gemini/gemini-handlers.go +++ /dev/null @@ -1,437 +0,0 @@ -// Package gemini provides HTTP handlers for Gemini API endpoints. -// This package implements handlers for managing Gemini model operations including -// model listing, content generation, streaming content generation, and token counting. -// It serves as a proxy layer between clients and the Gemini backend service, -// handling request translation, client management, and response processing. -package gemini - -import ( - "context" - "fmt" - "github.com/gin-gonic/gin" - "github.com/luispater/CLIProxyAPI/internal/api/handlers" - "github.com/luispater/CLIProxyAPI/internal/api/translator/gemini/cli" - "github.com/luispater/CLIProxyAPI/internal/client" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "net/http" - "strings" - "time" -) - -// GeminiAPIHandlers contains the handlers for Gemini API endpoints. -// It holds a pool of clients to interact with the backend service. -type GeminiAPIHandlers struct { - *handlers.APIHandlers -} - -// NewGeminiAPIHandlers creates a new Gemini API handlers instance. -// It takes an APIHandlers instance as input and returns a GeminiAPIHandlers. -func NewGeminiAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiAPIHandlers { - return &GeminiAPIHandlers{ - APIHandlers: apiHandlers, - } -} - -// GeminiModels handles the Gemini models listing endpoint. -// It returns a JSON response containing available Gemini models and their specifications. 
-func (h *GeminiAPIHandlers) GeminiModels(c *gin.Context) { - c.Status(http.StatusOK) - c.Header("Content-Type", "application/json; charset=UTF-8") - _, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `)) - _, _ = c.Writer.Write([]byte(`2.5 Flash","description":"Stable version of Gemini 2.5 Flash, our mid-size multimod`)) - _, _ = c.Writer.Write([]byte(`al model that supports up to 1 million tokens, released in June of 2025.","inputTok`)) - _, _ = c.Writer.Write([]byte(`enLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateCo`)) - _, _ = c.Writer.Write([]byte(`ntent","countTokens","createCachedContent","batchGenerateContent"],"temperature":1,`)) - _, _ = c.Writer.Write([]byte(`"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true},{"name":"models/gemini-2.`)) - _, _ = c.Writer.Write([]byte(`5-pro","version":"2.5","displayName":"Gemini 2.5 Pro","description":"Stable release`)) - _, _ = c.Writer.Write([]byte(` (June 17th, 2025) of Gemini 2.5 Pro","inputTokenLimit":1048576,"outputTokenLimit":`)) - _, _ = c.Writer.Write([]byte(`65536,"supportedGenerationMethods":["generateContent","countTokens","createCachedCo`)) - _, _ = c.Writer.Write([]byte(`ntent","batchGenerateContent"],"temperature":1,"topP":0.95,"topK":64,"maxTemperatur`)) - _, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`)) -} - -// GeminiGetHandler handles GET requests for specific Gemini model information. -// It returns detailed information about a specific Gemini model based on the action parameter. -func (h *GeminiAPIHandlers) GeminiGetHandler(c *gin.Context) { - var request struct { - Action string `uri:"action" binding:"required"` - } - if err := c.ShouldBindUri(&request); err != nil { - c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: fmt.Sprintf("Invalid request: %v", err), - Type: "invalid_request_error", - }, - }) - return - } - if request.Action == "gemini-2.5-pro" { - c.Status(http.StatusOK) - c.Header("Content-Type", "application/json; charset=UTF-8") - _, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-pro","version":"2.5","displayName":"Gemini 2.5 Pro",`)) - _, _ = c.Writer.Write([]byte(`"description":"Stable release (June 17th, 2025) of Gemini 2.5 Pro","inputTokenL`)) - _, _ = c.Writer.Write([]byte(`imit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateC`)) - _, _ = c.Writer.Write([]byte(`ontent","countTokens","createCachedContent","batchGenerateContent"],"temperatur`)) - _, _ = c.Writer.Write([]byte(`e":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`)) - } else if request.Action == "gemini-2.5-flash" { - c.Status(http.StatusOK) - c.Header("Content-Type", "application/json; charset=UTF-8") - _, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini 2.5 Fla`)) - _, _ = c.Writer.Write([]byte(`sh","description":"Stable version of Gemini 2.5 Flash, our mid-size multimodal `)) - _, _ = c.Writer.Write([]byte(`model that supports up to 1 million tokens, released in June of 2025.","inputTo`)) - _, _ = c.Writer.Write([]byte(`kenLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["gener`)) - _, _ = c.Writer.Write([]byte(`ateContent","countTokens","createCachedContent","batchGenerateContent"],"temper`)) - _, _ = c.Writer.Write([]byte(`ature":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`)) - } else { - c.Status(http.StatusNotFound) - _, _ = 
c.Writer.Write([]byte( - `{"error":{"message":"Not Found","code":404,"status":"NOT_FOUND"}}`, - )) - } -} - -// GeminiHandler handles POST requests for Gemini API operations. -// It routes requests to appropriate handlers based on the action parameter (model:method format). -func (h *GeminiAPIHandlers) GeminiHandler(c *gin.Context) { - var request struct { - Action string `uri:"action" binding:"required"` - } - if err := c.ShouldBindUri(&request); err != nil { - c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: fmt.Sprintf("Invalid request: %v", err), - Type: "invalid_request_error", - }, - }) - return - } - action := strings.Split(request.Action, ":") - if len(action) != 2 { - c.JSON(http.StatusNotFound, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: fmt.Sprintf("%s not found.", c.Request.URL.Path), - Type: "invalid_request_error", - }, - }) - return - } - - modelName := action[0] - method := action[1] - rawJSON, _ := c.GetRawData() - rawJSON, _ = sjson.SetBytes(rawJSON, "model", []byte(modelName)) - - if method == "generateContent" { - h.geminiGenerateContent(c, rawJSON) - } else if method == "streamGenerateContent" { - h.geminiStreamGenerateContent(c, rawJSON) - } else if method == "countTokens" { - h.geminiCountTokens(c, rawJSON) - } -} - -func (h *GeminiAPIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJSON []byte) { - alt := h.GetAlt(c) - - if alt == "" { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - } - - // Get the http.Flusher interface to manually flush the response. - flusher, ok := c.Writer.(http.Flusher) - if !ok { - c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: "Streaming not supported", - Type: "server_error", - }, - }) - return - } - - modelResult := gjson.GetBytes(rawJSON, "model") - modelName := modelResult.String() - - cliCtx, cliCancel := context.WithCancel(context.Background()) - var cliClient *client.Client - defer func() { - // Ensure the client's mutex is unlocked on function exit. 
- if cliClient != nil { - cliClient.RequestMutex.Unlock() - } - }() - -outLoop: - for { - var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.GetClient(modelName) - if errorResponse != nil { - c.Status(errorResponse.StatusCode) - _, _ = fmt.Fprint(c.Writer, errorResponse.Error) - flusher.Flush() - cliCancel() - return - } - - template := "" - parsed := gjson.Parse(string(rawJSON)) - contents := parsed.Get("request.contents") - if contents.Exists() { - template = string(rawJSON) - } else { - template = `{"project":"","request":{},"model":""}` - template, _ = sjson.SetRaw(template, "request", string(rawJSON)) - template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) - template, _ = sjson.Delete(template, "request.model") - } - - template, errFixCLIToolResponse := cli.FixCLIToolResponse(template) - if errFixCLIToolResponse != nil { - c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: errFixCLIToolResponse.Error(), - Type: "server_error", - }, - }) - cliCancel() - return - } - - systemInstructionResult := gjson.Get(template, "request.system_instruction") - if systemInstructionResult.Exists() { - template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) - template, _ = sjson.Delete(template, "request.system_instruction") - } - rawJSON = []byte(template) - - if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { - log.Debugf("Request use generative language API Key: %s", glAPIKey) - } else { - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - } - - // Send the message and receive response chunks and errors via channels. - respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, alt) - for { - select { - // Handle client disconnection. - case <-c.Request.Context().Done(): - if c.Request.Context().Err().Error() == "context canceled" { - log.Debugf("Client disconnected: %v", c.Request.Context().Err()) - cliCancel() // Cancel the backend request. - return - } - // Process incoming response chunks. - case chunk, okStream := <-respChan: - if !okStream { - cliCancel() - return - } - if cliClient.GetGenerativeLanguageAPIKey() == "" { - if alt == "" { - responseResult := gjson.GetBytes(chunk, "response") - if responseResult.Exists() { - chunk = []byte(responseResult.Raw) - } - } else { - chunkTemplate := "[]" - responseResult := gjson.ParseBytes(chunk) - if responseResult.IsArray() { - responseResultItems := responseResult.Array() - for i := 0; i < len(responseResultItems); i++ { - responseResultItem := responseResultItems[i] - if responseResultItem.Get("response").Exists() { - chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw) - } - } - } - chunk = []byte(chunkTemplate) - } - } - if alt == "" { - _, _ = c.Writer.Write([]byte("data: ")) - _, _ = c.Writer.Write(chunk) - _, _ = c.Writer.Write([]byte("\n\n")) - } else { - _, _ = c.Writer.Write(chunk) - } - flusher.Flush() - // Handle errors from the backend. - case err, okError := <-errChan: - if okError { - if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { - log.Debugf("quota exceeded, switch client") - continue outLoop - } else { - log.Debugf("error code :%d, error: %v", err.StatusCode, err.Error.Error()) - c.Status(err.StatusCode) - _, _ = fmt.Fprint(c.Writer, err.Error.Error()) - flusher.Flush() - cliCancel() - } - return - } - // Send a keep-alive signal to the client. 
- case <-time.After(500 * time.Millisecond): - } - } - } -} - -func (h *GeminiAPIHandlers) geminiCountTokens(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "application/json") - - alt := h.GetAlt(c) - // orgrawJSON := rawJSON - modelResult := gjson.GetBytes(rawJSON, "model") - modelName := modelResult.String() - cliCtx, cliCancel := context.WithCancel(context.Background()) - var cliClient *client.Client - defer func() { - if cliClient != nil { - cliClient.RequestMutex.Unlock() - } - }() - - for { - var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.GetClient(modelName, false) - if errorResponse != nil { - c.Status(errorResponse.StatusCode) - _, _ = fmt.Fprint(c.Writer, errorResponse.Error) - cliCancel() - return - } - - if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { - log.Debugf("Request use generative language API Key: %s", glAPIKey) - } else { - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - - template := `{"request":{}}` - if gjson.GetBytes(rawJSON, "generateContentRequest").Exists() { - template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJSON, "generateContentRequest").Raw) - template, _ = sjson.Delete(template, "generateContentRequest") - } else if gjson.GetBytes(rawJSON, "contents").Exists() { - template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJSON, "contents").Raw) - template, _ = sjson.Delete(template, "contents") - } - rawJSON = []byte(template) - } - - resp, err := cliClient.SendRawTokenCount(cliCtx, rawJSON, alt) - if err != nil { - if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { - continue - } else { - c.Status(err.StatusCode) - _, _ = c.Writer.Write([]byte(err.Error.Error())) - cliCancel() - // log.Debugf(err.Error.Error()) - // log.Debugf(string(rawJSON)) - // log.Debugf(string(orgrawJSON)) - } - break - } else { - if cliClient.GetGenerativeLanguageAPIKey() == "" { - responseResult := gjson.GetBytes(resp, "response") - if responseResult.Exists() { - resp = []byte(responseResult.Raw) - } - } - _, _ = c.Writer.Write(resp) - cliCancel() - break - } - } -} - -func (h *GeminiAPIHandlers) geminiGenerateContent(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "application/json") - - alt := h.GetAlt(c) - - modelResult := gjson.GetBytes(rawJSON, "model") - modelName := modelResult.String() - cliCtx, cliCancel := context.WithCancel(context.Background()) - var cliClient *client.Client - defer func() { - if cliClient != nil { - cliClient.RequestMutex.Unlock() - } - }() - - for { - var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.GetClient(modelName) - if errorResponse != nil { - c.Status(errorResponse.StatusCode) - _, _ = fmt.Fprint(c.Writer, errorResponse.Error) - cliCancel() - return - } - - template := "" - parsed := gjson.Parse(string(rawJSON)) - contents := parsed.Get("request.contents") - if contents.Exists() { - template = string(rawJSON) - } else { - template = `{"project":"","request":{},"model":""}` - template, _ = sjson.SetRaw(template, "request", string(rawJSON)) - template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) - template, _ = sjson.Delete(template, "request.model") - } - - template, errFixCLIToolResponse := cli.FixCLIToolResponse(template) - if errFixCLIToolResponse != nil { - c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: errFixCLIToolResponse.Error(), - Type: 
"server_error", - }, - }) - cliCancel() - return - } - - systemInstructionResult := gjson.Get(template, "request.system_instruction") - if systemInstructionResult.Exists() { - template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) - template, _ = sjson.Delete(template, "request.system_instruction") - } - rawJSON = []byte(template) - - if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { - log.Debugf("Request use generative language API Key: %s", glAPIKey) - } else { - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - } - resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, alt) - if err != nil { - if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { - continue - } else { - c.Status(err.StatusCode) - _, _ = c.Writer.Write([]byte(err.Error.Error())) - cliCancel() - } - break - } else { - if cliClient.GetGenerativeLanguageAPIKey() == "" { - responseResult := gjson.GetBytes(resp, "response") - if responseResult.Exists() { - resp = []byte(responseResult.Raw) - } - } - _, _ = c.Writer.Write(resp) - cliCancel() - break - } - } -} diff --git a/internal/api/handlers/gemini/gemini_handlers.go b/internal/api/handlers/gemini/gemini_handlers.go new file mode 100644 index 00000000..a2e10556 --- /dev/null +++ b/internal/api/handlers/gemini/gemini_handlers.go @@ -0,0 +1,735 @@ +// Package gemini provides HTTP handlers for Gemini API endpoints. +// This package implements handlers for managing Gemini model operations including +// model listing, content generation, streaming content generation, and token counting. +// It serves as a proxy layer between clients and the Gemini backend service, +// handling request translation, client management, and response processing. +package gemini + +import ( + "bytes" + "context" + "fmt" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/luispater/CLIProxyAPI/internal/api/handlers" + "github.com/luispater/CLIProxyAPI/internal/client" + translatorGeminiToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/gemini" + translatorGeminiToGeminiCli "github.com/luispater/CLIProxyAPI/internal/translator/gemini-cli/gemini/cli" + "github.com/luispater/CLIProxyAPI/internal/util" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// GeminiAPIHandlers contains the handlers for Gemini API endpoints. +// It holds a pool of clients to interact with the backend service. +type GeminiAPIHandlers struct { + *handlers.APIHandlers +} + +// NewGeminiAPIHandlers creates a new Gemini API handlers instance. +// It takes an APIHandlers instance as input and returns a GeminiAPIHandlers. +func NewGeminiAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiAPIHandlers { + return &GeminiAPIHandlers{ + APIHandlers: apiHandlers, + } +} + +// GeminiModels handles the Gemini models listing endpoint. +// It returns a JSON response containing available Gemini models and their specifications. 
+func (h *GeminiAPIHandlers) GeminiModels(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "data": []map[string]any{ + { + "id": "gemini-2.5-flash", + "object": "model", + "version": "001", + "name": "Gemini 2.5 Flash", + "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", + "context_length": 1_048_576, + "max_completion_tokens": 65_536, + "supported_parameters": []string{ + "tools", + "temperature", + "top_p", + "top_k", + }, + "temperature": 1, + "topP": 0.95, + "topK": 64, + "maxTemperature": 2, + "thinking": true, + }, + { + "id": "gemini-2.5-pro", + "object": "model", + "version": "2.5", + "name": "Gemini 2.5 Pro", + "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro", + "context_length": 1_048_576, + "max_completion_tokens": 65_536, + "supported_parameters": []string{ + "tools", + "temperature", + "top_p", + "top_k", + }, + "temperature": 1, + "topP": 0.95, + "topK": 64, + "maxTemperature": 2, + "thinking": true, + }, + { + "id": "gpt-5", + "object": "model", + "version": "gpt-5-2025-08-07", + "name": "GPT 5", + "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + "context_length": 400_000, + "max_completion_tokens": 128_000, + "supported_parameters": []string{ + "tools", + }, + "temperature": 1, + "topP": 0.95, + "topK": 64, + "maxTemperature": 2, + "thinking": true, + }, + }, + }) +} + +// GeminiGetHandler handles GET requests for specific Gemini model information. +// It returns detailed information about a specific Gemini model based on the action parameter. +func (h *GeminiAPIHandlers) GeminiGetHandler(c *gin.Context) { + var request struct { + Action string `uri:"action" binding:"required"` + } + if err := c.ShouldBindUri(&request); err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + switch request.Action { + case "gemini-2.5-pro": + c.JSON(http.StatusOK, gin.H{ + "id": "gemini-2.5-pro", + "object": "model", + "version": "2.5", + "name": "Gemini 2.5 Pro", + "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro", + "context_length": 1_048_576, + "max_completion_tokens": 65_536, + "supported_parameters": []string{ + "tools", + "temperature", + "top_p", + "top_k", + }, + "temperature": 1, + "topP": 0.95, + "topK": 64, + "maxTemperature": 2, + "thinking": true, + }) + case "gemini-2.5-flash": + c.JSON(http.StatusOK, gin.H{ + "id": "gemini-2.5-flash", + "object": "model", + "version": "001", + "name": "Gemini 2.5 Flash", + "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", + "context_length": 1_048_576, + "max_completion_tokens": 65_536, + "supported_parameters": []string{ + "tools", + "temperature", + "top_p", + "top_k", + }, + "temperature": 1, + "topP": 0.95, + "topK": 64, + "maxTemperature": 2, + "thinking": true, + }) + case "gpt-5": + c.JSON(http.StatusOK, gin.H{ + "id": "gpt-5", + "object": "model", + "version": "gpt-5-2025-08-07", + "name": "GPT 5", + "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + "context_length": 400_000, + "max_completion_tokens": 128_000, + "supported_parameters": []string{ + "tools", + }, + "temperature": 1, + "topP": 0.95, + "topK": 64, + "maxTemperature": 2, + 
"thinking": true, + }) + default: + c.JSON(http.StatusNotFound, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Not Found", + Type: "not_found", + }, + }) + } +} + +// GeminiHandler handles POST requests for Gemini API operations. +// It routes requests to appropriate handlers based on the action parameter (model:method format). +func (h *GeminiAPIHandlers) GeminiHandler(c *gin.Context) { + var request struct { + Action string `uri:"action" binding:"required"` + } + if err := c.ShouldBindUri(&request); err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + action := strings.Split(request.Action, ":") + if len(action) != 2 { + c.JSON(http.StatusNotFound, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("%s not found.", c.Request.URL.Path), + Type: "invalid_request_error", + }, + }) + return + } + + modelName := action[0] + method := action[1] + rawJSON, _ := c.GetRawData() + rawJSON, _ = sjson.SetBytes(rawJSON, "model", []byte(modelName)) + + provider := util.GetProviderName(modelName) + if provider == "gemini" || provider == "unknow" { + switch method { + case "generateContent": + h.handleGeminiGenerateContent(c, rawJSON) + case "streamGenerateContent": + h.handleGeminiStreamGenerateContent(c, rawJSON) + case "countTokens": + h.handleGeminiCountTokens(c, rawJSON) + } + } else if provider == "gpt" { + switch method { + case "generateContent": + h.handleCodexGenerateContent(c, rawJSON) + case "streamGenerateContent": + h.handleCodexStreamGenerateContent(c, rawJSON) + } + + } +} + +func (h *GeminiAPIHandlers) handleGeminiStreamGenerateContent(c *gin.Context, rawJSON []byte) { + alt := h.GetAlt(c) + + if alt == "" { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + } + + // Get the http.Flusher interface to manually flush the response. + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + modelResult := gjson.GetBytes(rawJSON, "model") + modelName := modelResult.String() + + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. 
+ if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + flusher.Flush() + cliCancel() + return + } + + template := "" + parsed := gjson.Parse(string(rawJSON)) + contents := parsed.Get("request.contents") + if contents.Exists() { + template = string(rawJSON) + } else { + template = `{"project":"","request":{},"model":""}` + template, _ = sjson.SetRaw(template, "request", string(rawJSON)) + template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) + template, _ = sjson.Delete(template, "request.model") + } + + template, errFixCLIToolResponse := translatorGeminiToGeminiCli.FixCLIToolResponse(template) + if errFixCLIToolResponse != nil { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: errFixCLIToolResponse.Error(), + Type: "server_error", + }, + }) + cliCancel() + return + } + + systemInstructionResult := gjson.Get(template, "request.system_instruction") + if systemInstructionResult.Exists() { + template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) + template, _ = sjson.Delete(template, "request.system_instruction") + } + rawJSON = []byte(template) + + if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID()) + } + + // Send the message and receive response chunks and errors via channels. + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, alt) + for { + select { + // Handle client disconnection. + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. + case chunk, okStream := <-respChan: + if !okStream { + cliCancel() + return + } + if cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey() == "" { + if alt == "" { + responseResult := gjson.GetBytes(chunk, "response") + if responseResult.Exists() { + chunk = []byte(responseResult.Raw) + } + } else { + chunkTemplate := "[]" + responseResult := gjson.ParseBytes(chunk) + if responseResult.IsArray() { + responseResultItems := responseResult.Array() + for i := 0; i < len(responseResultItems); i++ { + responseResultItem := responseResultItems[i] + if responseResultItem.Get("response").Exists() { + chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw) + } + } + } + chunk = []byte(chunkTemplate) + } + } + if alt == "" { + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + } else { + _, _ = c.Writer.Write(chunk) + } + flusher.Flush() + // Handle errors from the backend. 
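
The chunk-unwrapping branch above deals with the Code Assist envelope: every candidate arrives wrapped as `{"response": ...}`, and the shape differs between SSE (`alt == ""`) and the array form returned for other `alt` values. A sketch of both branches, with `repackChunk` as an illustrative name:

```go
package main

import (
	"fmt"

	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

// repackChunk unwraps Code Assist envelopes. For SSE the inner object is
// emitted as-is; for the array form, only the inner response objects are kept.
func repackChunk(chunk []byte, alt string) []byte {
	if alt == "" {
		if r := gjson.GetBytes(chunk, "response"); r.Exists() {
			return []byte(r.Raw)
		}
		return chunk
	}
	out := "[]"
	for _, item := range gjson.ParseBytes(chunk).Array() {
		if r := item.Get("response"); r.Exists() {
			out, _ = sjson.SetRaw(out, "-1", r.Raw) // "-1" appends to the array
		}
	}
	return []byte(out)
}

func main() {
	sse := []byte(`{"response":{"candidates":[]}}`)
	arr := []byte(`[{"response":{"candidates":[]}},{"other":1}]`)
	fmt.Println(string(repackChunk(sse, "")))     // {"candidates":[]}
	fmt.Println(string(repackChunk(arr, "json"))) // [{"candidates":[]}]
}
```
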
+ case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + log.Debugf("quota exceeded, switch client") + continue outLoop + } else { + log.Debugf("error code :%d, error: %v", err.StatusCode, err.Error.Error()) + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel() + } + return + } + // Send a keep-alive signal to the client. + case <-time.After(500 * time.Millisecond): + } + } + } +} + +func (h *GeminiAPIHandlers) handleGeminiCountTokens(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "application/json") + + alt := h.GetAlt(c) + // orgrawJSON := rawJSON + modelResult := gjson.GetBytes(rawJSON, "model") + modelName := modelResult.String() + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient client.Client + defer func() { + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName, false) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + cliCancel() + return + } + + if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID()) + + template := `{"request":{}}` + if gjson.GetBytes(rawJSON, "generateContentRequest").Exists() { + template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJSON, "generateContentRequest").Raw) + template, _ = sjson.Delete(template, "generateContentRequest") + } else if gjson.GetBytes(rawJSON, "contents").Exists() { + template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJSON, "contents").Raw) + template, _ = sjson.Delete(template, "contents") + } + rawJSON = []byte(template) + } + + resp, err := cliClient.SendRawTokenCount(cliCtx, rawJSON, alt) + if err != nil { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue + } else { + c.Status(err.StatusCode) + _, _ = c.Writer.Write([]byte(err.Error.Error())) + cliCancel() + // log.Debugf(err.Error.Error()) + // log.Debugf(string(rawJSON)) + // log.Debugf(string(orgrawJSON)) + } + break + } else { + if cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey() == "" { + responseResult := gjson.GetBytes(resp, "response") + if responseResult.Exists() { + resp = []byte(responseResult.Raw) + } + } + _, _ = c.Writer.Write(resp) + cliCancel() + break + } + } +} + +func (h *GeminiAPIHandlers) handleGeminiGenerateContent(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "application/json") + + alt := h.GetAlt(c) + + modelResult := gjson.GetBytes(rawJSON, "model") + modelName := modelResult.String() + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient client.Client + defer func() { + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + cliCancel() + return + } + + template := "" + parsed := gjson.Parse(string(rawJSON)) + contents := parsed.Get("request.contents") + if contents.Exists() { + template = 
string(rawJSON) + } else { + template = `{"project":"","request":{},"model":""}` + template, _ = sjson.SetRaw(template, "request", string(rawJSON)) + template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) + template, _ = sjson.Delete(template, "request.model") + } + + template, errFixCLIToolResponse := translatorGeminiToGeminiCli.FixCLIToolResponse(template) + if errFixCLIToolResponse != nil { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: errFixCLIToolResponse.Error(), + Type: "server_error", + }, + }) + cliCancel() + return + } + + systemInstructionResult := gjson.Get(template, "request.system_instruction") + if systemInstructionResult.Exists() { + template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) + template, _ = sjson.Delete(template, "request.system_instruction") + } + rawJSON = []byte(template) + + if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID()) + } + resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, alt) + if err != nil { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue + } else { + c.Status(err.StatusCode) + _, _ = c.Writer.Write([]byte(err.Error.Error())) + cliCancel() + } + break + } else { + if cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey() == "" { + responseResult := gjson.GetBytes(resp, "response") + if responseResult.Exists() { + resp = []byte(responseResult.Raw) + } + } + _, _ = c.Writer.Write(resp) + cliCancel() + break + } + } +} + +func (h *GeminiAPIHandlers) handleCodexStreamGenerateContent(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + + // Get the http.Flusher interface to manually flush the response. + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + // Prepare the request for the backend client. + newRequestJSON := translatorGeminiToCodex.ConvertGeminiRequestToCodex(rawJSON) + // log.Debugf("Request: %s", newRequestJSON) + + modelName := gjson.GetBytes(rawJSON, "model") + + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName.String()) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + flusher.Flush() + cliCancel() + return + } + + log.Debugf("Request codex use account: %s", cliClient.GetEmail()) + + // Send the message and receive response chunks and errors via channels. 
+ respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "") + params := &translatorGeminiToCodex.ConvertCodexResponseToGeminiParams{ + Model: modelName.String(), + CreatedAt: 0, + ResponseID: "", + LastStorageOutput: "", + } + for { + select { + // Handle client disconnection. + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. + case chunk, okStream := <-respChan: + if !okStream { + cliCancel() + return + } + + if bytes.HasPrefix(chunk, []byte("data: ")) { + jsonData := chunk[6:] + data := gjson.ParseBytes(jsonData) + typeResult := data.Get("type") + if typeResult.String() != "" { + outputs := translatorGeminiToCodex.ConvertCodexResponseToGemini(jsonData, params) + if len(outputs) > 0 { + for i := 0; i < len(outputs); i++ { + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write([]byte(outputs[i])) + _, _ = c.Writer.Write([]byte("\n\n")) + } + } + } + // log.Debugf(string(jsonData)) + } + flusher.Flush() + // Handle errors from the backend. + case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue outLoop + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel() + } + return + } + // Send a keep-alive signal to the client. + case <-time.After(500 * time.Millisecond): + } + } + } +} + +func (h *GeminiAPIHandlers) handleCodexGenerateContent(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "application/json") + + // Prepare the request for the backend client. + newRequestJSON := translatorGeminiToCodex.ConvertGeminiRequestToCodex(rawJSON) + // log.Debugf("Request: %s", newRequestJSON) + + modelName := gjson.GetBytes(rawJSON, "model") + + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName.String()) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + cliCancel() + return + } + + log.Debugf("Request codex use account: %s", cliClient.GetEmail()) + + // Send the message and receive response chunks and errors via channels. + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "") + for { + select { + // Handle client disconnection. + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. + case chunk, okStream := <-respChan: + if !okStream { + cliCancel() + return + } + + if bytes.HasPrefix(chunk, []byte("data: ")) { + jsonData := chunk[6:] + data := gjson.ParseBytes(jsonData) + typeResult := data.Get("type") + if typeResult.String() != "" { + var geminiStr string + geminiStr = translatorGeminiToCodex.ConvertCodexResponseToGeminiNonStream(jsonData, modelName.String()) + if geminiStr != "" { + _, _ = c.Writer.Write([]byte(geminiStr)) + } + } + } + // Handle errors from the backend. 
+ case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue outLoop + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + cliCancel() + } + return + } + // Send a keep-alive signal to the client. + case <-time.After(500 * time.Millisecond): + } + } + } +} diff --git a/internal/api/handlers/handlers.go b/internal/api/handlers/handlers.go index f6fe2326..fc80b168 100644 --- a/internal/api/handlers/handlers.go +++ b/internal/api/handlers/handlers.go @@ -5,52 +5,78 @@ package handlers import ( "fmt" + "sync" + "github.com/gin-gonic/gin" "github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/config" + "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" - "sync" ) // ErrorResponse represents a standard error response format for the API. // It contains a single ErrorDetail field. type ErrorResponse struct { + // Error contains detailed information about the error that occurred. Error ErrorDetail `json:"error"` } // ErrorDetail provides specific information about an error that occurred. // It includes a human-readable message, an error type, and an optional error code. type ErrorDetail struct { - // A human-readable message providing more details about the error. + // Message is a human-readable message providing more details about the error. Message string `json:"message"` - // The type of error that occurred (e.g., "invalid_request_error"). + + // Type is the category of error that occurred (e.g., "invalid_request_error"). Type string `json:"type"` - // A short code identifying the error, if applicable. + + // Code is a short code identifying the error, if applicable. Code string `json:"code,omitempty"` } // APIHandlers contains the handlers for API endpoints. -// It holds a pool of clients to interact with the backend service. +// It holds a pool of clients to interact with the backend service and manages +// load balancing, client selection, and configuration. type APIHandlers struct { - CliClients []*client.Client - Cfg *config.Config - Mutex *sync.Mutex - LastUsedClientIndex int + // CliClients is the pool of available AI service clients. + CliClients []client.Client + + // Cfg holds the current application configuration. + Cfg *config.Config + + // Mutex ensures thread-safe access to shared resources. + Mutex *sync.Mutex + + // LastUsedClientIndex tracks the last used client index for each provider + // to implement round-robin load balancing. + LastUsedClientIndex map[string]int } // NewAPIHandlers creates a new API handlers instance. -// It takes a slice of clients and a debug flag as input. -func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers { +// It takes a slice of clients and configuration as input. +// +// Parameters: +// - cliClients: A slice of AI service clients +// - cfg: The application configuration +// +// Returns: +// - *APIHandlers: A new API handlers instance +func NewAPIHandlers(cliClients []client.Client, cfg *config.Config) *APIHandlers { return &APIHandlers{ CliClients: cliClients, Cfg: cfg, Mutex: &sync.Mutex{}, - LastUsedClientIndex: 0, + LastUsedClientIndex: make(map[string]int), } } -// UpdateClients updates the handlers' client list and configuration -func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config) { +// UpdateClients updates the handlers' client list and configuration. 
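
`LastUsedClientIndex` changing from a single int to a `map[string]int` is the heart of the per-provider round robin: each provider advances its own cursor under the shared mutex, so heavy GPT traffic does not skew Gemini account rotation. A self-contained sketch; `rrPool` is an illustrative name:

```go
package main

import (
	"fmt"
	"sync"
)

// rrPool keeps an independent round-robin cursor per provider.
type rrPool struct {
	mu   sync.Mutex
	last map[string]int
}

// next advances the cursor for the given provider over n clients and returns
// the index of the client to use.
func (p *rrPool) next(provider string, n int) int {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.last == nil {
		p.last = make(map[string]int)
	}
	idx := (p.last[provider] + 1) % n
	p.last[provider] = idx
	return idx
}

func main() {
	p := &rrPool{}
	for i := 0; i < 4; i++ {
		fmt.Println("gemini ->", p.next("gemini", 3)) // 1 2 0 1
	}
	fmt.Println("gpt ->", p.next("gpt", 2)) // independent counter: 1
}
```
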
+// This method is called when the configuration or authentication tokens change. +// +// Parameters: +// - clients: The new slice of AI service clients +// - cfg: The new application configuration +func (h *APIHandlers) UpdateClients(clients []client.Client, cfg *config.Config) { h.CliClients = clients h.Cfg = cfg } @@ -58,30 +84,63 @@ func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config // GetClient returns an available client from the pool using round-robin load balancing. // It checks for quota limits and tries to find an unlocked client for immediate use. // The modelName parameter is used to check quota status for specific models. -func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) { - if len(h.CliClients) == 0 { +// +// Parameters: +// - modelName: The name of the model to be used +// - isGenerateContent: Optional parameter to indicate if this is for content generation +// +// Returns: +// - client.Client: An available client for the requested model +// - *client.ErrorMessage: An error message if no client is available +func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (client.Client, *client.ErrorMessage) { + provider := util.GetProviderName(modelName) + clients := make([]client.Client, 0) + if provider == "gemini" { + for i := 0; i < len(h.CliClients); i++ { + if cli, ok := h.CliClients[i].(*client.GeminiClient); ok { + clients = append(clients, cli) + } + } + } else if provider == "gpt" { + for i := 0; i < len(h.CliClients); i++ { + if cli, ok := h.CliClients[i].(*client.CodexClient); ok { + clients = append(clients, cli) + } + } + } + + if _, hasKey := h.LastUsedClientIndex[provider]; !hasKey { + h.LastUsedClientIndex[provider] = 0 + } + + if len(clients) == 0 { return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")} } - var cliClient *client.Client + var cliClient client.Client // Lock the mutex to update the last used client index h.Mutex.Lock() - startIndex := h.LastUsedClientIndex + startIndex := h.LastUsedClientIndex[provider] if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 { - currentIndex := (startIndex + 1) % len(h.CliClients) - h.LastUsedClientIndex = currentIndex + currentIndex := (startIndex + 1) % len(clients) + h.LastUsedClientIndex[provider] = currentIndex } h.Mutex.Unlock() // Reorder the client to start from the last used index - reorderedClients := make([]*client.Client, 0) - for i := 0; i < len(h.CliClients); i++ { - cliClient = h.CliClients[(startIndex+1+i)%len(h.CliClients)] + reorderedClients := make([]client.Client, 0) + for i := 0; i < len(clients); i++ { + cliClient = clients[(startIndex+1+i)%len(clients)] if cliClient.IsModelQuotaExceeded(modelName) { - log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) + if provider == "gemini" { + log.Debugf("Gemini Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.(*client.GeminiClient).GetProjectID()) + } else if provider == "gpt" { + log.Debugf("Codex Model %s is quota exceeded for account %s", modelName, cliClient.GetEmail()) + } cliClient = nil continue + } reorderedClients = append(reorderedClients, cliClient) } @@ -93,14 +152,14 @@ func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (*c locked := false for i := 0; i < len(reorderedClients); i++ { cliClient = 
reorderedClients[i] - if cliClient.RequestMutex.TryLock() { + if cliClient.GetRequestMutex().TryLock() { locked = true break } } if !locked { - cliClient = h.CliClients[0] - cliClient.RequestMutex.Lock() + cliClient = clients[0] + cliClient.GetRequestMutex().Lock() } return cliClient, nil @@ -108,6 +167,12 @@ func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (*c // GetAlt extracts the 'alt' parameter from the request query string. // It checks both 'alt' and '$alt' parameters and returns the appropriate value. +// +// Parameters: +// - c: The Gin context containing the HTTP request +// +// Returns: +// - string: The alt parameter value, or empty string if it's "sse" func (h *APIHandlers) GetAlt(c *gin.Context) string { var alt string var hasAlt bool diff --git a/internal/api/handlers/openai/openai-handlers.go b/internal/api/handlers/openai/openai-handlers.go deleted file mode 100644 index 7623278e..00000000 --- a/internal/api/handlers/openai/openai-handlers.go +++ /dev/null @@ -1,264 +0,0 @@ -// Package openai provides HTTP handlers for OpenAI API endpoints. -// This package implements the OpenAI-compatible API interface, including model listing -// and chat completion functionality. It supports both streaming and non-streaming responses, -// and manages a pool of clients to interact with backend services. -// The handlers translate OpenAI API requests to the appropriate backend format and -// convert responses back to OpenAI-compatible format. -package openai - -import ( - "context" - "fmt" - "github.com/luispater/CLIProxyAPI/internal/api/handlers" - "github.com/luispater/CLIProxyAPI/internal/api/translator/openai" - "github.com/luispater/CLIProxyAPI/internal/client" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "net/http" - "time" - - "github.com/gin-gonic/gin" -) - -// OpenAIAPIHandlers contains the handlers for OpenAI API endpoints. -// It holds a pool of clients to interact with the backend service. -type OpenAIAPIHandlers struct { - *handlers.APIHandlers -} - -// NewOpenAIAPIHandlers creates a new OpenAI API handlers instance. -// It takes an APIHandlers instance as input and returns an OpenAIAPIHandlers. -func NewOpenAIAPIHandlers(apiHandlers *handlers.APIHandlers) *OpenAIAPIHandlers { - return &OpenAIAPIHandlers{ - APIHandlers: apiHandlers, - } -} - -// Models handles the /v1/models endpoint. -// It returns a hardcoded list of available AI models. 
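
The selection tail above probes each healthy client with `TryLock`, so an idle account is claimed without blocking, and only when every account is busy does the request queue on a plain `Lock`. A minimal sketch of that strategy, using bare mutexes in place of the client type:

```go
package main

import (
	"fmt"
	"sync"
)

// acquire probes each candidate with TryLock so an idle client is used
// immediately; when all clients are busy it falls back to a blocking Lock on
// the first one.
func acquire(mutexes []*sync.Mutex) int {
	for i, m := range mutexes {
		if m.TryLock() {
			return i // found an idle client
		}
	}
	mutexes[0].Lock() // all busy: queue on the first client
	return 0
}

func main() {
	pool := []*sync.Mutex{new(sync.Mutex), new(sync.Mutex)}
	pool[0].Lock() // simulate a client mid-request
	i := acquire(pool)
	fmt.Println("acquired client", i) // 1
	pool[i].Unlock()
	pool[0].Unlock()
}
```
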
-func (h *OpenAIAPIHandlers) Models(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "data": []map[string]any{ - { - "id": "gemini-2.5-pro", - "object": "model", - "version": "2.5", - "name": "Gemini 2.5 Pro", - "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro", - "context_length": 1048576, - "max_completion_tokens": 65536, - "supported_parameters": []string{ - "tools", - "temperature", - "top_p", - "top_k", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, - { - "id": "gemini-2.5-flash", - "object": "model", - "version": "001", - "name": "Gemini 2.5 Flash", - "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", - "context_length": 1048576, - "max_completion_tokens": 65536, - "supported_parameters": []string{ - "tools", - "temperature", - "top_p", - "top_k", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, - }, - }) -} - -// ChatCompletions handles the /v1/chat/completions endpoint. -// It determines whether the request is for a streaming or non-streaming response -// and calls the appropriate handler. -func (h *OpenAIAPIHandlers) ChatCompletions(c *gin.Context) { - rawJSON, err := c.GetRawData() - // If data retrieval fails, return a 400 Bad Request error. - if err != nil { - c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: fmt.Sprintf("Invalid request: %v", err), - Type: "invalid_request_error", - }, - }) - return - } - - // Check if the client requested a streaming response. - streamResult := gjson.GetBytes(rawJSON, "stream") - if streamResult.Type == gjson.True { - h.handleStreamingResponse(c, rawJSON) - } else { - h.handleNonStreamingResponse(c, rawJSON) - } -} - -// handleNonStreamingResponse handles non-streaming chat completion responses. -// It selects a client from the pool, sends the request, and aggregates the response -// before sending it back to the client. 
-func (h *OpenAIAPIHandlers) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "application/json") - - modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON) - cliCtx, cliCancel := context.WithCancel(context.Background()) - var cliClient *client.Client - defer func() { - if cliClient != nil { - cliClient.RequestMutex.Unlock() - } - }() - - for { - var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.GetClient(modelName) - if errorResponse != nil { - c.Status(errorResponse.StatusCode) - _, _ = fmt.Fprint(c.Writer, errorResponse.Error) - cliCancel() - return - } - - isGlAPIKey := false - if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { - log.Debugf("Request use generative language API Key: %s", glAPIKey) - isGlAPIKey = true - } else { - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - } - - resp, err := cliClient.SendMessage(cliCtx, rawJSON, modelName, systemInstruction, contents, tools) - if err != nil { - if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { - continue - } else { - c.Status(err.StatusCode) - _, _ = c.Writer.Write([]byte(err.Error.Error())) - cliCancel() - } - break - } else { - openAIFormat := openai.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey) - if openAIFormat != "" { - _, _ = c.Writer.Write([]byte(openAIFormat)) - } - cliCancel() - break - } - } -} - -// handleStreamingResponse handles streaming responses -func (h *OpenAIAPIHandlers) handleStreamingResponse(c *gin.Context, rawJSON []byte) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") - - // Get the http.Flusher interface to manually flush the response. - flusher, ok := c.Writer.(http.Flusher) - if !ok { - c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ - Error: handlers.ErrorDetail{ - Message: "Streaming not supported", - Type: "server_error", - }, - }) - return - } - - // Prepare the request for the backend client. - modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON) - cliCtx, cliCancel := context.WithCancel(context.Background()) - var cliClient *client.Client - defer func() { - // Ensure the client's mutex is unlocked on function exit. - if cliClient != nil { - cliClient.RequestMutex.Unlock() - } - }() - -outLoop: - for { - var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.GetClient(modelName) - if errorResponse != nil { - c.Status(errorResponse.StatusCode) - _, _ = fmt.Fprint(c.Writer, errorResponse.Error) - flusher.Flush() - cliCancel() - return - } - - isGlAPIKey := false - if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { - log.Debugf("Request use generative language API Key: %s", glAPIKey) - isGlAPIKey = true - } else { - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - } - // Send the message and receive response chunks and errors via channels. - respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools) - hasFirstResponse := false - for { - select { - // Handle client disconnection. - case <-c.Request.Context().Done(): - if c.Request.Context().Err().Error() == "context canceled" { - log.Debugf("Client disconnected: %v", c.Request.Context().Err()) - cliCancel() // Cancel the backend request. 
-                    return
-                }
-            // Process incoming response chunks.
-            case chunk, okStream := <-respChan:
-                if !okStream {
-                    // Stream is closed, send the final [DONE] message.
-                    _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
-                    flusher.Flush()
-                    cliCancel()
-                    return
-                }
-                // Convert the chunk to OpenAI format and send it to the client.
-                hasFirstResponse = true
-                openAIFormat := openai.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey)
-                if openAIFormat != "" {
-                    _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
-                    flusher.Flush()
-                }
-            // Handle errors from the backend.
-            case err, okError := <-errChan:
-                if okError {
-                    if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
-                        continue outLoop
-                    } else {
-                        c.Status(err.StatusCode)
-                        _, _ = fmt.Fprint(c.Writer, err.Error.Error())
-                        flusher.Flush()
-                        cliCancel()
-                    }
-                    return
-                }
-            // Send a keep-alive signal to the client.
-            case <-time.After(500 * time.Millisecond):
-                if hasFirstResponse {
-                    _, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
-                    flusher.Flush()
-                }
-            }
-        }
-    }
-}
diff --git a/internal/api/handlers/openai/openai_handlers.go b/internal/api/handlers/openai/openai_handlers.go
new file mode 100644
index 00000000..2a2370bb
--- /dev/null
+++ b/internal/api/handlers/openai/openai_handlers.go
@@ -0,0 +1,506 @@
+// Package openai provides HTTP handlers for OpenAI API endpoints.
+// This package implements the OpenAI-compatible API interface, including model listing
+// and chat completion functionality. It supports both streaming and non-streaming responses,
+// and manages a pool of clients to interact with backend services.
+// The handlers translate OpenAI API requests to the appropriate backend format and
+// convert responses back to OpenAI-compatible format.
+package openai
+
+import (
+    "bytes"
+    "context"
+    "fmt"
+    "net/http"
+    "time"
+
+    "github.com/luispater/CLIProxyAPI/internal/api/handlers"
+    "github.com/luispater/CLIProxyAPI/internal/client"
+    translatorOpenAIToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/openai"
+    translatorOpenAIToGeminiCli "github.com/luispater/CLIProxyAPI/internal/translator/gemini-cli/openai"
+    "github.com/luispater/CLIProxyAPI/internal/util"
+    log "github.com/sirupsen/logrus"
+    "github.com/tidwall/gjson"
+
+    "github.com/gin-gonic/gin"
+)
+
+// OpenAIAPIHandlers contains the handlers for OpenAI API endpoints.
+// It holds a pool of clients to interact with the backend service.
+type OpenAIAPIHandlers struct {
+    *handlers.APIHandlers
+}
+
+// NewOpenAIAPIHandlers creates a new OpenAI API handlers instance.
+// It takes an APIHandlers instance as input and returns an OpenAIAPIHandlers.
+//
+// Parameters:
+// - apiHandlers: The base API handlers instance
+//
+// Returns:
+// - *OpenAIAPIHandlers: A new OpenAI API handlers instance
+func NewOpenAIAPIHandlers(apiHandlers *handlers.APIHandlers) *OpenAIAPIHandlers {
+    return &OpenAIAPIHandlers{
+        APIHandlers: apiHandlers,
+    }
+}
+
+// Models handles the /v1/models endpoint.
+// It returns a hardcoded list of available AI models with their capabilities
+// and specifications in OpenAI-compatible format.
+func (h *OpenAIAPIHandlers) Models(c *gin.Context) {
+    c.JSON(http.StatusOK, gin.H{
+        "data": []map[string]any{
+            {
+                "id": "gemini-2.5-pro",
+                "object": "model",
+                "version": "2.5",
+                "name": "Gemini 2.5 Pro",
+                "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
+                "context_length": 1_048_576,
+                "max_completion_tokens": 65_536,
+                "supported_parameters": []string{
+                    "tools",
+                    "temperature",
+                    "top_p",
+                    "top_k",
+                },
+                "temperature": 1,
+                "topP": 0.95,
+                "topK": 64,
+                "maxTemperature": 2,
+                "thinking": true,
+            },
+            {
+                "id": "gemini-2.5-flash",
+                "object": "model",
+                "version": "001",
+                "name": "Gemini 2.5 Flash",
+                "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
+                "context_length": 1_048_576,
+                "max_completion_tokens": 65_536,
+                "supported_parameters": []string{
+                    "tools",
+                    "temperature",
+                    "top_p",
+                    "top_k",
+                },
+                "temperature": 1,
+                "topP": 0.95,
+                "topK": 64,
+                "maxTemperature": 2,
+                "thinking": true,
+            },
+            {
+                "id": "gpt-5",
+                "object": "model",
+                "version": "gpt-5-2025-08-07",
+                "name": "GPT 5",
+                "description": "Stable version of GPT-5, the best model for coding and agentic tasks across domains.",
+                "context_length": 400_000,
+                "max_completion_tokens": 128_000,
+                "supported_parameters": []string{
+                    "tools",
+                },
+                "temperature": 1,
+                "topP": 0.95,
+                "topK": 64,
+                "maxTemperature": 2,
+                "thinking": true,
+            },
+        },
+    })
+}
+
+// ChatCompletions handles the /v1/chat/completions endpoint.
+// It determines whether the request is for a streaming or non-streaming response
+// and calls the appropriate handler based on the model provider.
+//
+// Parameters:
+// - c: The Gin context containing the HTTP request and response
+func (h *OpenAIAPIHandlers) ChatCompletions(c *gin.Context) {
+    rawJSON, err := c.GetRawData()
+    // If data retrieval fails, return a 400 Bad Request error.
+    if err != nil {
+        c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
+            Error: handlers.ErrorDetail{
+                Message: fmt.Sprintf("Invalid request: %v", err),
+                Type: "invalid_request_error",
+            },
+        })
+        return
+    }
+
+    // Check if the client requested a streaming response.
+    streamResult := gjson.GetBytes(rawJSON, "stream")
+    modelName := gjson.GetBytes(rawJSON, "model")
+    provider := util.GetProviderName(modelName.String())
+    if provider == "gemini" {
+        if streamResult.Type == gjson.True {
+            h.handleGeminiStreamingResponse(c, rawJSON)
+        } else {
+            h.handleGeminiNonStreamingResponse(c, rawJSON)
+        }
+    } else if provider == "gpt" {
+        if streamResult.Type == gjson.True {
+            h.handleCodexStreamingResponse(c, rawJSON)
+        } else {
+            h.handleCodexNonStreamingResponse(c, rawJSON)
+        }
+    } else {
+        // Unknown model prefix: reject explicitly instead of silently returning an empty response.
+        c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
+            Error: handlers.ErrorDetail{
+                Message: fmt.Sprintf("Unsupported model: %s", modelName.String()),
+                Type: "invalid_request_error",
+            },
+        })
+    }
+}
+
+// handleGeminiNonStreamingResponse handles non-streaming chat completion responses
+// for Gemini models. It selects a client from the pool, sends the request, and
+// aggregates the response before sending it back to the client in OpenAI format.
+//
+// Parameters:
+// - c: The Gin context containing the HTTP request and response
+// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
+func (h *OpenAIAPIHandlers) handleGeminiNonStreamingResponse(c *gin.Context, rawJSON []byte) {
+    c.Header("Content-Type", "application/json")
+
+    modelName, systemInstruction, contents, tools := translatorOpenAIToGeminiCli.ConvertOpenAIChatRequestToCli(rawJSON)
+    cliCtx, cliCancel := context.WithCancel(context.Background())
+    var cliClient client.Client
+    defer func() {
+        if cliClient != nil {
+            cliClient.GetRequestMutex().Unlock()
+        }
+    }()
+
+    for {
+        var errorResponse *client.ErrorMessage
+        cliClient, errorResponse = h.GetClient(modelName)
+        if errorResponse != nil {
+            c.Status(errorResponse.StatusCode)
+            _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
+            cliCancel()
+            return
+        }
+
+        isGlAPIKey := false
+        if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" {
+            log.Debugf("Request use generative language API Key: %s", glAPIKey)
+            isGlAPIKey = true
+        } else {
+            log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID())
+        }
+
+        resp, err := cliClient.SendMessage(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
+        if err != nil {
+            if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
+                continue
+            } else {
+                c.Status(err.StatusCode)
+                _, _ = c.Writer.Write([]byte(err.Error.Error()))
+                cliCancel()
+            }
+            break
+        } else {
+            openAIFormat := translatorOpenAIToGeminiCli.ConvertCliResponseToOpenAIChatNonStream(resp, time.Now().Unix(), isGlAPIKey)
+            if openAIFormat != "" {
+                _, _ = c.Writer.Write([]byte(openAIFormat))
+            }
+            cliCancel()
+            break
+        }
+    }
+}
+
+// handleGeminiStreamingResponse handles streaming responses for Gemini models.
+// It establishes a streaming connection with the backend service and forwards
+// the response chunks to the client in real-time using Server-Sent Events.
+//
+// Parameters:
+// - c: The Gin context containing the HTTP request and response
+// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
+func (h *OpenAIAPIHandlers) handleGeminiStreamingResponse(c *gin.Context, rawJSON []byte) {
+    c.Header("Content-Type", "text/event-stream")
+    c.Header("Cache-Control", "no-cache")
+    c.Header("Connection", "keep-alive")
+    c.Header("Access-Control-Allow-Origin", "*")
+
+    // Get the http.Flusher interface to manually flush the response.
+    flusher, ok := c.Writer.(http.Flusher)
+    if !ok {
+        c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+            Error: handlers.ErrorDetail{
+                Message: "Streaming not supported",
+                Type: "server_error",
+            },
+        })
+        return
+    }
+
+    // Prepare the request for the backend client.
+    modelName, systemInstruction, contents, tools := translatorOpenAIToGeminiCli.ConvertOpenAIChatRequestToCli(rawJSON)
+    cliCtx, cliCancel := context.WithCancel(context.Background())
+    var cliClient client.Client
+    defer func() {
+        // Ensure the client's mutex is unlocked on function exit.
+        if cliClient != nil {
+            cliClient.GetRequestMutex().Unlock()
+        }
+    }()
+
+outLoop:
+    for {
+        var errorResponse *client.ErrorMessage
+        cliClient, errorResponse = h.GetClient(modelName)
+        if errorResponse != nil {
+            c.Status(errorResponse.StatusCode)
+            _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
+            flusher.Flush()
+            cliCancel()
+            return
+        }
+
+        isGlAPIKey := false
+        if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" {
+            log.Debugf("Request use generative language API Key: %s", glAPIKey)
+            isGlAPIKey = true
+        } else {
+            log.Debugf("Request cli use account: %s, project id: %s", cliClient.GetEmail(), cliClient.(*client.GeminiClient).GetProjectID())
+        }
+        // Send the message and receive response chunks and errors via channels.
+        respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
+        hasFirstResponse := false
+        for {
+            select {
+            // Handle client disconnection.
+            case <-c.Request.Context().Done():
+                if c.Request.Context().Err().Error() == "context canceled" {
+                    log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err())
+                    cliCancel() // Cancel the backend request.
+                    return
+                }
+            // Process incoming response chunks.
+            case chunk, okStream := <-respChan:
+                if !okStream {
+                    // Stream is closed, send the final [DONE] message.
+                    _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
+                    flusher.Flush()
+                    cliCancel()
+                    return
+                }
+                // Convert the chunk to OpenAI format and send it to the client.
+                hasFirstResponse = true
+                openAIFormat := translatorOpenAIToGeminiCli.ConvertCliResponseToOpenAIChat(chunk, time.Now().Unix(), isGlAPIKey)
+                if openAIFormat != "" {
+                    _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
+                    flusher.Flush()
+                }
+            // Handle errors from the backend.
+            case err, okError := <-errChan:
+                if okError {
+                    if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
+                        continue outLoop
+                    } else {
+                        c.Status(err.StatusCode)
+                        _, _ = fmt.Fprint(c.Writer, err.Error.Error())
+                        flusher.Flush()
+                        cliCancel()
+                    }
+                    return
+                }
+            // Send a keep-alive signal to the client.
+            case <-time.After(500 * time.Millisecond):
+                if hasFirstResponse {
+                    _, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
+                    flusher.Flush()
+                }
+            }
+        }
+    }
+}
+
+// handleCodexNonStreamingResponse handles non-streaming chat completion responses
+// for OpenAI models. It selects a client from the pool, sends the request, and
+// aggregates the response before sending it back to the client in OpenAI format.
+//
+// Parameters:
+// - c: The Gin context containing the HTTP request and response
+// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
+func (h *OpenAIAPIHandlers) handleCodexNonStreamingResponse(c *gin.Context, rawJSON []byte) {
+    c.Header("Content-Type", "application/json")
+
+    newRequestJSON := translatorOpenAIToCodex.ConvertOpenAIChatRequestToCodex(rawJSON)
+    modelName := gjson.GetBytes(rawJSON, "model")
+    cliCtx, cliCancel := context.WithCancel(context.Background())
+    var cliClient client.Client
+    defer func() {
+        if cliClient != nil {
+            cliClient.GetRequestMutex().Unlock()
+        }
+    }()
+
+outLoop:
+    for {
+        var errorResponse *client.ErrorMessage
+        cliClient, errorResponse = h.GetClient(modelName.String())
+        if errorResponse != nil {
+            c.Status(errorResponse.StatusCode)
+            _, _ = c.Writer.Write([]byte(errorResponse.Error.Error()))
+            cliCancel()
+            return
+        }
+
+        log.Debugf("Request codex use account: %s", cliClient.GetEmail())
+
+        // Send the message and receive response chunks and errors via channels.
+        respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
+        for {
+            select {
+            // Handle client disconnection.
+            case <-c.Request.Context().Done():
+                if c.Request.Context().Err().Error() == "context canceled" {
+                    log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
+                    cliCancel() // Cancel the backend request.
+                    return
+                }
+            // Process incoming response chunks.
+            case chunk, okStream := <-respChan:
+                if !okStream {
+                    cliCancel()
+                    return
+                }
+                if bytes.HasPrefix(chunk, []byte("data: ")) {
+                    jsonData := chunk[6:]
+                    data := gjson.ParseBytes(jsonData)
+                    typeResult := data.Get("type")
+                    if typeResult.String() == "response.completed" {
+                        responseResult := data.Get("response")
+                        openaiStr := translatorOpenAIToCodex.ConvertCodexResponseToOpenAIChatNonStream(responseResult.Raw, time.Now().Unix())
+                        _, _ = c.Writer.Write([]byte(openaiStr))
+                    }
+                }
+            // Handle errors from the backend.
+            case err, okError := <-errChan:
+                if okError {
+                    if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
+                        continue outLoop
+                    } else {
+                        c.Status(err.StatusCode)
+                        _, _ = c.Writer.Write([]byte(err.Error.Error()))
+                        cliCancel()
+                    }
+                    return
+                }
+            // Periodically wake up the select loop; the Codex path sends no keep-alive comment.
+            case <-time.After(500 * time.Millisecond):
+            }
+        }
+    }
+}
+
+// handleCodexStreamingResponse handles streaming responses for OpenAI models.
+// It establishes a streaming connection with the backend service and forwards
+// the response chunks to the client in real-time using Server-Sent Events.
+//
+// Parameters:
+// - c: The Gin context containing the HTTP request and response
+// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
+func (h *OpenAIAPIHandlers) handleCodexStreamingResponse(c *gin.Context, rawJSON []byte) {
+    c.Header("Content-Type", "text/event-stream")
+    c.Header("Cache-Control", "no-cache")
+    c.Header("Connection", "keep-alive")
+    c.Header("Access-Control-Allow-Origin", "*")
+
+    // Get the http.Flusher interface to manually flush the response.
+    flusher, ok := c.Writer.(http.Flusher)
+    if !ok {
+        c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+            Error: handlers.ErrorDetail{
+                Message: "Streaming not supported",
+                Type: "server_error",
+            },
+        })
+        return
+    }
+
+    // Prepare the request for the backend client.
+    newRequestJSON := translatorOpenAIToCodex.ConvertOpenAIChatRequestToCodex(rawJSON)
+    // log.Debugf("Request: %s", newRequestJSON)
+
+    modelName := gjson.GetBytes(rawJSON, "model")
+
+    cliCtx, cliCancel := context.WithCancel(context.Background())
+    var cliClient client.Client
+    defer func() {
+        // Ensure the client's mutex is unlocked on function exit.
+        if cliClient != nil {
+            cliClient.GetRequestMutex().Unlock()
+        }
+    }()
+
+outLoop:
+    for {
+        var errorResponse *client.ErrorMessage
+        cliClient, errorResponse = h.GetClient(modelName.String())
+        if errorResponse != nil {
+            c.Status(errorResponse.StatusCode)
+            _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
+            flusher.Flush()
+            cliCancel()
+            return
+        }
+
+        log.Debugf("Request codex use account: %s", cliClient.GetEmail())
+
+        // Send the message and receive response chunks and errors via channels.
+        var params *translatorOpenAIToCodex.ConvertCliToOpenAIParams
+        respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
+        for {
+            select {
+            // Handle client disconnection.
+            case <-c.Request.Context().Done():
+                if c.Request.Context().Err().Error() == "context canceled" {
+                    log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
+                    cliCancel() // Cancel the backend request.
+                    return
+                }
+            // Process incoming response chunks.
+            case chunk, okStream := <-respChan:
+                if !okStream {
+                    // Stream is closed, send the final [DONE] message in SSE format.
+                    _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
+                    flusher.Flush()
+                    cliCancel()
+                    return
+                }
+                // log.Debugf("Response: %s\n", string(chunk))
+                // Convert the chunk to OpenAI format and send it to the client.
+                if bytes.HasPrefix(chunk, []byte("data: ")) {
+                    jsonData := chunk[6:]
+                    data := gjson.ParseBytes(jsonData)
+                    typeResult := data.Get("type")
+                    if typeResult.String() != "" {
+                        var openaiStr string
+                        params, openaiStr = translatorOpenAIToCodex.ConvertCodexResponseToOpenAIChat(jsonData, params)
+                        if openaiStr != "" {
+                            _, _ = c.Writer.Write([]byte("data: "))
+                            _, _ = c.Writer.Write([]byte(openaiStr))
+                            _, _ = c.Writer.Write([]byte("\n\n"))
+                        }
+                    }
+                    // log.Debugf(string(jsonData))
+                }
+                flusher.Flush()
+            // Handle errors from the backend.
+            case err, okError := <-errChan:
+                if okError {
+                    if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
+                        continue outLoop
+                    } else {
+                        c.Status(err.StatusCode)
+                        _, _ = fmt.Fprint(c.Writer, err.Error.Error())
+                        flusher.Flush()
+                        cliCancel()
+                    }
+                    return
+                }
+            // Periodically wake up the select loop; the Codex path sends no keep-alive comment.
+            case <-time.After(500 * time.Millisecond):
+            }
+        }
+    }
+}
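ChatCompletions above routes on `util.GetProviderName`, which is not part of this patch. A minimal sketch of what such a prefix-based router might look like, inferred from the `"gemini"`/`"gpt"` values consumed by the handler (an assumption, not the actual implementation):

```go
package util

import "strings"

// GetProviderName maps a model name to a provider label.
// Hypothetical sketch: the real helper may support more providers.
func GetProviderName(model string) string {
	switch {
	case strings.HasPrefix(model, "gemini-"):
		return "gemini"
	case strings.HasPrefix(model, "gpt-"):
		return "gpt"
	default:
		return "" // unknown model family
	}
}
```
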
diff --git a/internal/api/middleware/request_logging.go b/internal/api/middleware/request_logging.go
new file mode 100644
index 00000000..4ec69ee2
--- /dev/null
+++ b/internal/api/middleware/request_logging.go
@@ -0,0 +1,88 @@
+// Package middleware provides HTTP middleware components for the CLI Proxy API server.
+// This file contains the request logging middleware that captures comprehensive
+// request and response data when enabled through configuration.
+package middleware
+
+import (
+    "bytes"
+    "io"
+
+    "github.com/gin-gonic/gin"
+    "github.com/luispater/CLIProxyAPI/internal/logging"
+)
+
+// RequestLoggingMiddleware creates a Gin middleware function that logs HTTP requests and responses
+// when enabled through the provided logger. The middleware has zero overhead when logging is disabled.
+func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
+    return func(c *gin.Context) {
+        // Early return if logging is disabled (zero overhead)
+        if !logger.IsEnabled() {
+            c.Next()
+            return
+        }
+
+        // Capture request information
+        requestInfo, err := captureRequestInfo(c)
+        if err != nil {
+            // Log error but continue processing
+            // In a real implementation, you might want to use a proper logger here
+            c.Next()
+            return
+        }
+
+        // Create response writer wrapper
+        wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
+        c.Writer = wrapper
+
+        // Process the request
+        c.Next()
+
+        // Finalize logging after request processing
+        if err := wrapper.Finalize(); err != nil {
+            // Log error but don't interrupt the response
+            // In a real implementation, you might want to use a proper logger here
+        }
+    }
+}
+
+// captureRequestInfo extracts and captures request information for logging.
+func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
+    // Capture URL
+    url := c.Request.URL.String()
+    if c.Request.URL.Path != "" {
+        url = c.Request.URL.Path
+        if c.Request.URL.RawQuery != "" {
+            url += "?" + c.Request.URL.RawQuery
+        }
+    }
+
+    // Capture method
+    method := c.Request.Method
+
+    // Capture headers
+    headers := make(map[string][]string)
+    for key, values := range c.Request.Header {
+        headers[key] = values
+    }
+
+    // Capture request body
+    var body []byte
+    if c.Request.Body != nil {
+        // Read the body
+        bodyBytes, err := io.ReadAll(c.Request.Body)
+        if err != nil {
+            return nil, err
+        }
+
+        // Restore the body for the actual request processing
+        c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
+        body = bodyBytes
+    }
+
+    return &RequestInfo{
+        URL:     url,
+        Method:  method,
+        Headers: headers,
+        Body:    body,
+    }, nil
+}
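The middleware programs against `logging.RequestLogger`, whose definition is outside this patch. Inferring the method set from the calls above and in the response writer below, a no-op implementation usable in tests might look like this; the interface shape here is an assumption, not authoritative:

```go
package logging

// RequestLogger and StreamingLogWriter as inferred from middleware usage
// (assumed shapes; the real interfaces may differ).
type RequestLogger interface {
	IsEnabled() bool
	LogRequest(url, method string, reqHeaders map[string][]string, reqBody []byte,
		status int, respHeaders map[string][]string, respBody []byte) error
	LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error)
}

type StreamingLogWriter interface {
	WriteStatus(status int, headers map[string][]string) error
	WriteChunkAsync(chunk []byte)
	Close() error
}

// NopRequestLogger disables logging entirely; IsEnabled() == false makes the
// middleware return early, so the other methods are never reached.
type NopRequestLogger struct{}

func (NopRequestLogger) IsEnabled() bool { return false }

func (NopRequestLogger) LogRequest(_, _ string, _ map[string][]string, _ []byte,
	_ int, _ map[string][]string, _ []byte) error {
	return nil
}

func (NopRequestLogger) LogStreamingRequest(_, _ string, _ map[string][]string, _ []byte) (StreamingLogWriter, error) {
	return nil, nil
}
```
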
diff --git a/internal/api/middleware/response_writer.go b/internal/api/middleware/response_writer.go
new file mode 100644
index 00000000..d4897855
--- /dev/null
+++ b/internal/api/middleware/response_writer.go
@@ -0,0 +1,208 @@
+// Package middleware provides HTTP middleware components for the CLI Proxy API server.
+// This includes request logging middleware and response writer wrappers that capture
+// request and response data for logging purposes while maintaining zero-latency performance.
+package middleware
+
+import (
+    "bytes"
+    "strings"
+
+    "github.com/gin-gonic/gin"
+    "github.com/luispater/CLIProxyAPI/internal/logging"
+)
+
+// RequestInfo holds information about the current request for logging purposes.
+type RequestInfo struct {
+    URL     string
+    Method  string
+    Headers map[string][]string
+    Body    []byte
+}
+
+// ResponseWriterWrapper wraps gin.ResponseWriter to capture response data for logging.
+// It maintains zero-latency performance by prioritizing client response over logging operations.
+type ResponseWriterWrapper struct {
+    gin.ResponseWriter
+    body         *bytes.Buffer
+    isStreaming  bool
+    streamWriter logging.StreamingLogWriter
+    chunkChannel chan []byte
+    logger       logging.RequestLogger
+    requestInfo  *RequestInfo
+    statusCode   int
+    headers      map[string][]string
+}
+
+// NewResponseWriterWrapper creates a new response writer wrapper.
+func NewResponseWriterWrapper(w gin.ResponseWriter, logger logging.RequestLogger, requestInfo *RequestInfo) *ResponseWriterWrapper {
+    return &ResponseWriterWrapper{
+        ResponseWriter: w,
+        body:           &bytes.Buffer{},
+        logger:         logger,
+        requestInfo:    requestInfo,
+        headers:        make(map[string][]string),
+    }
+}
+
+// Write intercepts response data while maintaining normal Gin functionality.
+// CRITICAL: This method prioritizes client response (zero-latency) over logging operations.
+func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
+    // CRITICAL: Write to client first (zero latency)
+    n, err := w.ResponseWriter.Write(data)
+
+    // THEN: Handle logging based on response type
+    if w.isStreaming {
+        // For streaming responses: Send to async logging channel (non-blocking)
+        if w.chunkChannel != nil {
+            select {
+            case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
+            default: // Channel full, skip logging to avoid blocking
+            }
+        }
+    } else {
+        // For non-streaming responses: Buffer complete response
+        w.body.Write(data)
+    }
+
+    return n, err
+}
+
+// WriteHeader captures the status code and detects streaming responses.
+func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
+    w.statusCode = statusCode
+
+    // Capture response headers
+    for key, values := range w.ResponseWriter.Header() {
+        w.headers[key] = values
+    }
+
+    // Detect streaming based on Content-Type
+    contentType := w.ResponseWriter.Header().Get("Content-Type")
+    w.isStreaming = w.detectStreaming(contentType)
+
+    // If streaming, initialize streaming log writer
+    if w.isStreaming && w.logger.IsEnabled() {
+        streamWriter, err := w.logger.LogStreamingRequest(
+            w.requestInfo.URL,
+            w.requestInfo.Method,
+            w.requestInfo.Headers,
+            w.requestInfo.Body,
+        )
+        if err == nil {
+            w.streamWriter = streamWriter
+            w.chunkChannel = make(chan []byte, 100) // Buffered channel for async writes
+
+            // Start async chunk processor
+            go w.processStreamingChunks()
+
+            // Write status immediately
+            _ = streamWriter.WriteStatus(statusCode, w.headers)
+        }
+    }
+
+    // Call original WriteHeader
+    w.ResponseWriter.WriteHeader(statusCode)
+}
+
+// detectStreaming determines if the response is streaming based on Content-Type and request analysis.
+func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
+    // Check Content-Type for Server-Sent Events
+    if strings.Contains(contentType, "text/event-stream") {
+        return true
+    }
+
+    // Check request body for streaming indicators
+    if w.requestInfo.Body != nil {
+        bodyStr := string(w.requestInfo.Body)
+        if strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`) {
+            return true
+        }
+    }
+
+    return false
+}
+
+// processStreamingChunks handles async processing of streaming chunks.
+func (w *ResponseWriterWrapper) processStreamingChunks() {
+    if w.streamWriter == nil || w.chunkChannel == nil {
+        return
+    }
+
+    for chunk := range w.chunkChannel {
+        w.streamWriter.WriteChunkAsync(chunk)
+    }
+}
+
+// Finalize completes the logging process for the response.
+func (w *ResponseWriterWrapper) Finalize() error {
+    if !w.logger.IsEnabled() {
+        return nil
+    }
+
+    if w.isStreaming {
+        // Close streaming channel and writer
+        if w.chunkChannel != nil {
+            close(w.chunkChannel)
+            w.chunkChannel = nil
+        }
+
+        if w.streamWriter != nil {
+            return w.streamWriter.Close()
+        }
+    } else {
+        // Capture final status code and headers if not already captured
+        finalStatusCode := w.statusCode
+        if finalStatusCode == 0 {
+            // Get status from underlying ResponseWriter if available
+            if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
+                finalStatusCode = statusWriter.Status()
+            } else {
+                finalStatusCode = 200 // Default
+            }
+        }
+
+        // Capture final headers
+        finalHeaders := make(map[string][]string)
+        for key, values := range w.ResponseWriter.Header() {
+            finalHeaders[key] = values
+        }
+        // Merge with any headers we captured earlier
+        for key, values := range w.headers {
+            finalHeaders[key] = values
+        }
+
+        // Log complete non-streaming response
+        return w.logger.LogRequest(
+            w.requestInfo.URL,
+            w.requestInfo.Method,
+            w.requestInfo.Headers,
+            w.requestInfo.Body,
+            finalStatusCode,
+            finalHeaders,
+            w.body.Bytes(),
+        )
+    }
+
+    return nil
+}
+
+// Status returns the HTTP status code of the response.
+func (w *ResponseWriterWrapper) Status() int {
+    if w.statusCode == 0 {
+        return 200 // Default status code
+    }
+    return w.statusCode
+}
+
+// Size returns the size of the response body.
+func (w *ResponseWriterWrapper) Size() int {
+    if w.isStreaming {
+        return -1 // Unknown size for streaming responses
+    }
+    return w.body.Len()
+}
+
+// Written returns whether the response has been written.
+func (w *ResponseWriterWrapper) Written() bool {
+    return w.statusCode != 0
+}
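One detail worth calling out in `Write`: the chunk is copied with `append([]byte(nil), data...)` before the non-blocking send, because the caller may reuse the `data` slice after `Write` returns. The pattern in isolation, as a standalone sketch:

```go
// queueChunk queues a chunk for async logging without ever blocking the
// client write path. Returns false if the chunk was dropped.
func queueChunk(ch chan []byte, data []byte) bool {
	chunk := append([]byte(nil), data...) // defensive copy: the caller may reuse data
	select {
	case ch <- chunk:
		return true
	default:
		return false // channel full: drop the log entry rather than stall the response
	}
}
```
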
diff --git a/internal/api/server.go b/internal/api/server.go
index 612151b8..9d2791e1 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -8,31 +8,48 @@ import (
     "context"
     "errors"
     "fmt"
+    "net/http"
+    "strings"
+
     "github.com/gin-gonic/gin"
     "github.com/luispater/CLIProxyAPI/internal/api/handlers"
     "github.com/luispater/CLIProxyAPI/internal/api/handlers/claude"
     "github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini"
     "github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini/cli"
     "github.com/luispater/CLIProxyAPI/internal/api/handlers/openai"
+    "github.com/luispater/CLIProxyAPI/internal/api/middleware"
     "github.com/luispater/CLIProxyAPI/internal/client"
     "github.com/luispater/CLIProxyAPI/internal/config"
+    "github.com/luispater/CLIProxyAPI/internal/logging"
     log "github.com/sirupsen/logrus"
-    "net/http"
-    "strings"
 )
 
 // Server represents the main API server.
 // It encapsulates the Gin engine, HTTP server, handlers, and configuration.
 type Server struct {
-    engine *gin.Engine
-    server *http.Server
+    // engine is the Gin web framework engine instance.
+    engine *gin.Engine
+
+    // server is the underlying HTTP server.
+    server *http.Server
+
+    // handlers contains the API handlers for processing requests.
     handlers *handlers.APIHandlers
-    cfg      *config.Config
+
+    // cfg holds the current server configuration.
+    cfg *config.Config
 }
 
 // NewServer creates and initializes a new API server instance.
 // It sets up the Gin engine, middleware, routes, and handlers.
-func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
+//
+// Parameters:
+// - cfg: The server configuration
+// - cliClients: A slice of AI service clients
+//
+// Returns:
+// - *Server: A new server instance
+func NewServer(cfg *config.Config, cliClients []client.Client) *Server {
     // Set gin mode
     if !cfg.Debug {
         gin.SetMode(gin.ReleaseMode)
@@ -44,6 +61,11 @@ func NewServer(cfg *config.Config, cliClients []client.Client) *Server {
     // Add middleware
     engine.Use(gin.Logger())
     engine.Use(gin.Recovery())
+
+    // Add request logging middleware (positioned after recovery, before auth)
+    requestLogger := logging.NewFileRequestLogger(cfg.RequestLog, "logs")
+    engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
+
     engine.Use(corsMiddleware())
 
     // Create server instance
@@ -103,11 +125,13 @@ func (s *Server) setupRoutes() {
         })
     })
     s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
-
 }
 
 // Start begins listening for and serving HTTP requests.
 // It's a blocking call and will only return on an unrecoverable error.
+//
+// Returns:
+// - error: An error if the server fails to start
 func (s *Server) Start() error {
     log.Debugf("Starting API server on %s", s.server.Addr)
@@ -121,6 +145,12 @@ func (s *Server) Start() error {
 
 // Stop gracefully shuts down the API server without interrupting any
 // active connections.
+//
+// Parameters:
+// - ctx: The context for graceful shutdown
+//
+// Returns:
+// - error: An error if the server fails to stop
 func (s *Server) Stop(ctx context.Context) error {
     log.Debug("Stopping API server...")
@@ -135,6 +165,9 @@ func (s *Server) Stop(ctx context.Context) error {
 
 // corsMiddleware returns a Gin middleware handler that adds CORS headers
 // to every response, allowing cross-origin requests.
+//
+// Returns:
+// - gin.HandlerFunc: The CORS middleware handler
 func corsMiddleware() gin.HandlerFunc {
     return func(c *gin.Context) {
         c.Header("Access-Control-Allow-Origin", "*")
@@ -150,8 +183,13 @@ func corsMiddleware() gin.HandlerFunc {
     }
 }
 
-// UpdateClients updates the server's client list and configuration
-func (s *Server) UpdateClients(clients []*client.Client, cfg *config.Config) {
+// UpdateClients updates the server's client list and configuration.
+// This method is called when the configuration or authentication tokens change.
+//
+// Parameters:
+// - clients: The new slice of AI service clients
+// - cfg: The new application configuration
+func (s *Server) UpdateClients(clients []client.Client, cfg *config.Config) {
     s.cfg = cfg
     s.handlers.UpdateClients(clients, cfg)
     log.Infof("server clients and configuration updated: %d clients", len(clients))
@@ -159,6 +197,12 @@ func (s *Server) UpdateClients(clients []client.Client, cfg *config.Config) {
 
 // AuthMiddleware returns a Gin middleware handler that authenticates requests
 // using API keys. If no API keys are configured, it allows all requests.
+//
+// Parameters:
+// - cfg: The server configuration containing API keys
+//
+// Returns:
+// - gin.HandlerFunc: The authentication middleware handler
 func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
     return func(c *gin.Context) {
         if len(cfg.APIKeys) == 0 {
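With AuthMiddleware in the chain (its body is truncated in this hunk), requests must carry one of the configured API keys. A rough client-side sketch; the key value is a placeholder that must match an entry in the server's `api-keys` list:

```go
package main

import (
	"fmt"
	"net/http"
	"strings"
)

func main() {
	body := strings.NewReader(`{"model":"gemini-2.5-pro","messages":[{"role":"user","content":"hi"}]}`)
	req, err := http.NewRequest("POST", "http://localhost:8317/v1/chat/completions", body)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer your-api-key-1") // placeholder API key

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
```
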
diff --git a/internal/auth/codex/errors.go b/internal/auth/codex/errors.go
new file mode 100644
index 00000000..55df5e04
--- /dev/null
+++ b/internal/auth/codex/errors.go
@@ -0,0 +1,155 @@
+package codex
+
+import (
+    "errors"
+    "fmt"
+    "net/http"
+)
+
+// OAuthError represents an OAuth-specific error
+type OAuthError struct {
+    Code        string `json:"error"`
+    Description string `json:"error_description,omitempty"`
+    URI         string `json:"error_uri,omitempty"`
+    StatusCode  int    `json:"-"`
+}
+
+func (e *OAuthError) Error() string {
+    if e.Description != "" {
+        return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
+    }
+    return fmt.Sprintf("OAuth error: %s", e.Code)
+}
+
+// NewOAuthError creates a new OAuth error
+func NewOAuthError(code, description string, statusCode int) *OAuthError {
+    return &OAuthError{
+        Code:        code,
+        Description: description,
+        StatusCode:  statusCode,
+    }
+}
+
+// AuthenticationError represents authentication-related errors
+type AuthenticationError struct {
+    Type    string `json:"type"`
+    Message string `json:"message"`
+    Code    int    `json:"code"`
+    Cause   error  `json:"-"`
+}
+
+func (e *AuthenticationError) Error() string {
+    if e.Cause != nil {
+        return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
+    }
+    return fmt.Sprintf("%s: %s", e.Type, e.Message)
+}
+
+// Common authentication error types
+var (
+    ErrTokenExpired = &AuthenticationError{
+        Type:    "token_expired",
+        Message: "Access token has expired",
+        Code:    http.StatusUnauthorized,
+    }
+
+    ErrInvalidState = &AuthenticationError{
+        Type:    "invalid_state",
+        Message: "OAuth state parameter is invalid",
+        Code:    http.StatusBadRequest,
+    }
+
+    ErrCodeExchangeFailed = &AuthenticationError{
+        Type:    "code_exchange_failed",
+        Message: "Failed to exchange authorization code for tokens",
+        Code:    http.StatusBadRequest,
+    }
+
+    ErrServerStartFailed = &AuthenticationError{
+        Type:    "server_start_failed",
+        Message: "Failed to start OAuth callback server",
+        Code:    http.StatusInternalServerError,
+    }
+
+    ErrPortInUse = &AuthenticationError{
+        Type:    "port_in_use",
+        Message: "OAuth callback port is already in use",
+        Code:    13, // Special exit code for port-in-use
+    }
+
+    ErrCallbackTimeout = &AuthenticationError{
+        Type:    "callback_timeout",
+        Message: "Timeout waiting for OAuth callback",
+        Code:    http.StatusRequestTimeout,
+    }
+
+    ErrBrowserOpenFailed = &AuthenticationError{
+        Type:    "browser_open_failed",
+        Message: "Failed to open browser for authentication",
+        Code:    http.StatusInternalServerError,
+    }
+)
+
+// NewAuthenticationError creates a new authentication error with a cause
+func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
+    return &AuthenticationError{
+        Type:    baseErr.Type,
+        Message: baseErr.Message,
+        Code:    baseErr.Code,
+        Cause:   cause,
+    }
+}
+
+// IsAuthenticationError checks if an error is an authentication error
+func IsAuthenticationError(err error) bool {
+    var authenticationError *AuthenticationError
+    ok := errors.As(err, &authenticationError)
+    return ok
+}
+
+// IsOAuthError checks if an error is an OAuth error
+func IsOAuthError(err error) bool {
+    var oAuthError *OAuthError
+    ok := errors.As(err, &oAuthError)
+    return ok
+}
+
+// GetUserFriendlyMessage returns a user-friendly error message
+func GetUserFriendlyMessage(err error) string {
+    switch {
+    case IsAuthenticationError(err):
+        var authErr *AuthenticationError
+        errors.As(err, &authErr)
+        switch authErr.Type {
+        case "token_expired":
+            return "Your authentication has expired. Please log in again."
+        case "token_invalid":
+            return "Your authentication is invalid. Please log in again."
+        case "authentication_required":
+            return "Please log in to continue."
+        case "port_in_use":
+            return "The required port is already in use. Please close any applications using port 1455 and try again."
+        case "callback_timeout":
+            return "Authentication timed out. Please try again."
+        case "browser_open_failed":
+            return "Could not open your browser automatically. Please copy and paste the URL manually."
+        default:
+            return "Authentication failed. Please try again."
+        }
+    case IsOAuthError(err):
+        var oauthErr *OAuthError
+        errors.As(err, &oauthErr)
+        switch oauthErr.Code {
+        case "access_denied":
+            return "Authentication was cancelled or denied."
+        case "invalid_request":
+            return "Invalid authentication request. Please try again."
+        case "server_error":
+            return "Authentication server error. Please try again later."
+        default:
+            return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
+        }
+    default:
+        return "An unexpected error occurred. Please try again."
+    }
+}
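A short usage sketch of these helpers, wrapping a low-level failure and rendering it for the user (the failure text is illustrative):

```go
package main

import (
	"fmt"

	"github.com/luispater/CLIProxyAPI/internal/auth/codex"
)

func main() {
	// Simulate the listener failing because the callback port is taken.
	cause := fmt.Errorf("listen tcp :1455: address already in use")
	err := codex.NewAuthenticationError(codex.ErrPortInUse, cause)

	fmt.Println(codex.IsAuthenticationError(err))  // true
	fmt.Println(codex.GetUserFriendlyMessage(err)) // port-in-use guidance for the user
}
```
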
diff --git a/internal/auth/codex/html_templates.go b/internal/auth/codex/html_templates.go
new file mode 100644
index 00000000..9be62b5d
--- /dev/null
+++ b/internal/auth/codex/html_templates.go
@@ -0,0 +1,210 @@
+package codex
+
+// LoginSuccessHtml is the template for the OAuth success page.
+// (Inline styling omitted here; only the structure and text content are shown.)
+const LoginSuccessHtml = `<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="utf-8">
+    <title>Authentication Successful - Codex</title>
+</head>
+<body>
+    <div class="container">
+        <h1>Authentication Successful!</h1>
+
+        <p>You have successfully authenticated with Codex. You can now close this window and return to your terminal to continue.</p>
+
+        {{SETUP_NOTICE}}
+
+        <a href="{{PLATFORM_URL}}" target="_blank">Open Platform</a>
+
+        <p>This window will close automatically in 10 seconds</p>
+    </div>
+
+    <script>
+        setTimeout(function () { window.close(); }, 10000);
+    </script>
+</body>
+</html>`
+
+// SetupNoticeHtml is the template for the setup notice section
+const SetupNoticeHtml = `
+<div class="setup-notice">
+    <h2>Additional Setup Required</h2>
+    <p>To complete your setup, please visit the <a href="{{PLATFORM_URL}}">Codex</a> to configure your account.</p>
+</div>
+`
diff --git a/internal/auth/codex/jwt_parser.go b/internal/auth/codex/jwt_parser.go
new file mode 100644
index 00000000..6302cca7
--- /dev/null
+++ b/internal/auth/codex/jwt_parser.go
@@ -0,0 +1,89 @@
+package codex
+
+import (
+    "encoding/base64"
+    "encoding/json"
+    "fmt"
+    "strings"
+    "time"
+)
+
+// JWTClaims represents the claims section of a JWT token
+type JWTClaims struct {
+    AtHash        string        `json:"at_hash"`
+    Aud           []string      `json:"aud"`
+    AuthProvider  string        `json:"auth_provider"`
+    AuthTime      int           `json:"auth_time"`
+    Email         string        `json:"email"`
+    EmailVerified bool          `json:"email_verified"`
+    Exp           int           `json:"exp"`
+    CodexAuthInfo CodexAuthInfo `json:"https://api.openai.com/auth"`
+    Iat           int           `json:"iat"`
+    Iss           string        `json:"iss"`
+    Jti           string        `json:"jti"`
+    Rat           int           `json:"rat"`
+    Sid           string        `json:"sid"`
+    Sub           string        `json:"sub"`
+}
+
+type Organizations struct {
+    ID        string `json:"id"`
+    IsDefault bool   `json:"is_default"`
+    Role      string `json:"role"`
+    Title     string `json:"title"`
+}
+
+type CodexAuthInfo struct {
+    ChatgptAccountID               string          `json:"chatgpt_account_id"`
+    ChatgptPlanType                string          `json:"chatgpt_plan_type"`
+    ChatgptSubscriptionActiveStart any             `json:"chatgpt_subscription_active_start"`
+    ChatgptSubscriptionActiveUntil any             `json:"chatgpt_subscription_active_until"`
+    ChatgptSubscriptionLastChecked time.Time       `json:"chatgpt_subscription_last_checked"`
+    ChatgptUserID                  string          `json:"chatgpt_user_id"`
+    Groups                         []any           `json:"groups"`
+    Organizations                  []Organizations `json:"organizations"`
+    UserID                         string          `json:"user_id"`
+}
+
+// ParseJWTToken parses a JWT token and extracts the claims without verification
+// This is used for extracting user information from ID tokens
+func ParseJWTToken(token string) (*JWTClaims, error) {
+    parts := strings.Split(token, ".")
+    if len(parts) != 3 {
+        return nil, fmt.Errorf("invalid JWT token format: expected 3 parts, got %d", len(parts))
+    }
+
+    // Decode the claims (payload) part
+    claimsData, err := base64URLDecode(parts[1])
+    if err != nil {
+        return nil, fmt.Errorf("failed to decode JWT claims: %w", err)
+    }
+
+    var claims JWTClaims
+    if err = json.Unmarshal(claimsData, &claims); err != nil {
+        return nil, fmt.Errorf("failed to unmarshal JWT claims: %w", err)
+    }
+
+    return &claims, nil
+}
+
+// base64URLDecode decodes a base64 URL-encoded string with proper padding
+func base64URLDecode(data string) ([]byte, error) {
+    // Add padding if necessary
+    switch len(data) % 4 {
+    case 2:
+        data += "=="
+    case 3:
+        data += "="
+    }
+
+    return base64.URLEncoding.DecodeString(data)
+}
+
+// GetUserEmail extracts the user email from JWT claims
+func (c *JWTClaims) GetUserEmail() string {
+    return c.Email
+}
+
+// GetAccountID extracts the user ID from JWT claims (subject)
+func (c *JWTClaims) GetAccountID() string {
+    return c.CodexAuthInfo.ChatgptAccountID
+}
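ParseJWTToken deliberately decodes the payload without verifying the signature, which is acceptable here because the ID token arrives directly from the issuer over TLS. A quick usage sketch with a synthetic, unsigned token:

```go
package main

import (
	"encoding/base64"
	"fmt"

	"github.com/luispater/CLIProxyAPI/internal/auth/codex"
)

func main() {
	// Build a dummy three-part JWT; header and signature contents are irrelevant
	// to ParseJWTToken, which only decodes the payload.
	enc := base64.URLEncoding.WithPadding(base64.NoPadding)
	header := enc.EncodeToString([]byte(`{"alg":"none"}`))
	payload := enc.EncodeToString([]byte(`{"email":"user@example.com","https://api.openai.com/auth":{"chatgpt_account_id":"acct_123"}}`))
	token := header + "." + payload + ".sig"

	claims, err := codex.ParseJWTToken(token)
	if err != nil {
		panic(err)
	}
	fmt.Println(claims.GetUserEmail()) // user@example.com
	fmt.Println(claims.GetAccountID()) // acct_123
}
```
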
diff --git a/internal/auth/codex/oauth_server.go b/internal/auth/codex/oauth_server.go
new file mode 100644
index 00000000..8f8085d2
--- /dev/null
+++ b/internal/auth/codex/oauth_server.go
@@ -0,0 +1,244 @@
+package codex
+
+import (
+    "context"
+    "errors"
+    "fmt"
+    "net"
+    "net/http"
+    "strings"
+    "sync"
+    "time"
+
+    log "github.com/sirupsen/logrus"
+)
+
+// OAuthServer handles the local HTTP server for OAuth callbacks
+type OAuthServer struct {
+    server     *http.Server
+    port       int
+    resultChan chan *OAuthResult
+    errorChan  chan error
+    mu         sync.Mutex
+    running    bool
+}
+
+// OAuthResult contains the result of the OAuth callback
+type OAuthResult struct {
+    Code  string
+    State string
+    Error string
+}
+
+// NewOAuthServer creates a new OAuth callback server
+func NewOAuthServer(port int) *OAuthServer {
+    return &OAuthServer{
+        port:       port,
+        resultChan: make(chan *OAuthResult, 1),
+        errorChan:  make(chan error, 1),
+    }
+}
+
+// Start starts the OAuth callback server
+func (s *OAuthServer) Start(ctx context.Context) error {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+
+    if s.running {
+        return fmt.Errorf("server is already running")
+    }
+
+    // Check if port is available
+    if !s.isPortAvailable() {
+        return fmt.Errorf("port %d is already in use", s.port)
+    }
+
+    mux := http.NewServeMux()
+    mux.HandleFunc("/auth/callback", s.handleCallback)
+    mux.HandleFunc("/success", s.handleSuccess)
+
+    s.server = &http.Server{
+        Addr:         fmt.Sprintf(":%d", s.port),
+        Handler:      mux,
+        ReadTimeout:  10 * time.Second,
+        WriteTimeout: 10 * time.Second,
+    }
+
+    s.running = true
+
+    // Start server in goroutine
+    go func() {
+        if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+            s.errorChan <- fmt.Errorf("server failed to start: %w", err)
+        }
+    }()
+
+    // Give server a moment to start
+    time.Sleep(100 * time.Millisecond)
+
+    return nil
+}
+
+// Stop gracefully stops the OAuth callback server
+func (s *OAuthServer) Stop(ctx context.Context) error {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+
+    if !s.running || s.server == nil {
+        return nil
+    }
+
+    log.Debug("Stopping OAuth callback server")
+
+    // Create a context with timeout for shutdown
+    shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+    defer cancel()
+
+    err := s.server.Shutdown(shutdownCtx)
+    s.running = false
+    s.server = nil
+
+    return err
+}
+
+// WaitForCallback waits for the OAuth callback with a timeout
+func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
+    select {
+    case result := <-s.resultChan:
+        return result, nil
+    case err := <-s.errorChan:
+        return nil, err
+    case <-time.After(timeout):
+        return nil, fmt.Errorf("timeout waiting for OAuth callback")
+    }
+}
+
+// handleCallback handles the OAuth callback endpoint
+func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
+    log.Debug("Received OAuth callback")
+
+    // Validate request method
+    if r.Method != http.MethodGet {
+        http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+        return
+    }
+
+    // Extract parameters
+    query := r.URL.Query()
+    code := query.Get("code")
+    state := query.Get("state")
+    errorParam := query.Get("error")
+
+    // Validate required parameters
+    if errorParam != "" {
+        log.Errorf("OAuth error received: %s", errorParam)
+        result := &OAuthResult{
+            Error: errorParam,
+        }
+        s.sendResult(result)
+        http.Error(w, fmt.Sprintf("OAuth error: %s", errorParam), http.StatusBadRequest)
+        return
+    }
+
+    if code == "" {
+        log.Error("No authorization code received")
+        result := &OAuthResult{
+            Error: "no_code",
+        }
+        s.sendResult(result)
+        http.Error(w, "No authorization code received", http.StatusBadRequest)
+        return
+    }
+
+    if state == "" {
+        log.Error("No state parameter received")
+        result := &OAuthResult{
+            Error: "no_state",
+        }
+        s.sendResult(result)
+        http.Error(w, "No state parameter received", http.StatusBadRequest)
+        return
+    }
+
+    // Send successful result
+    result := &OAuthResult{
+        Code:  code,
+        State: state,
+    }
+    s.sendResult(result)
+
+    // Redirect to success page
+    http.Redirect(w, r, "/success", http.StatusFound)
+}
+
+// handleSuccess handles the success page endpoint
+func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
+    log.Debug("Serving success page")
+
+    w.Header().Set("Content-Type", "text/html; charset=utf-8")
+    w.WriteHeader(http.StatusOK)
+
+    // Parse query parameters for customization
+    query := r.URL.Query()
+    setupRequired := query.Get("setup_required") == "true"
+    platformURL := query.Get("platform_url")
+    if platformURL == "" {
+        platformURL = "https://platform.openai.com"
+    }
+
+    // Generate success page HTML with dynamic content
+    successHTML := s.generateSuccessHTML(setupRequired, platformURL)
+
+    _, err := w.Write([]byte(successHTML))
+    if err != nil {
+        log.Errorf("Failed to write success page: %v", err)
+    }
+}
+
+// generateSuccessHTML creates the HTML content for the success page
+func (s *OAuthServer) generateSuccessHTML(setupRequired bool, platformURL string) string {
+    html := LoginSuccessHtml
+
+    // Replace platform URL placeholder
+    html = strings.Replace(html, "{{PLATFORM_URL}}", platformURL, -1)
+
+    // Add setup notice if required
+    if setupRequired {
+        setupNotice := strings.Replace(SetupNoticeHtml, "{{PLATFORM_URL}}", platformURL, -1)
+        html = strings.Replace(html, "{{SETUP_NOTICE}}", setupNotice, 1)
+    } else {
+        html = strings.Replace(html, "{{SETUP_NOTICE}}", "", 1)
+    }
+
+    return html
+}
+
+// sendResult sends the OAuth result to the waiting channel
+func (s *OAuthServer) sendResult(result *OAuthResult) {
+    select {
+    case s.resultChan <- result:
+        log.Debug("OAuth result sent to channel")
+    default:
+        log.Warn("OAuth result channel is full, result dropped")
+    }
+}
+
+// isPortAvailable checks if the specified port is available
+func (s *OAuthServer) isPortAvailable() bool {
+    addr := fmt.Sprintf(":%d", s.port)
+    listener, err := net.Listen("tcp", addr)
+    if err != nil {
+        return false
+    }
+    defer func() {
+        _ = listener.Close()
+    }()
+    return true
+}
+
+// IsRunning returns whether the server is currently running
+func (s *OAuthServer) IsRunning() bool {
+    s.mu.Lock()
+    defer s.mu.Unlock()
+    return s.running
+}
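Taken together, the callback server supports a login flow like the following sketch (the timeout value and state handling are illustrative):

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/luispater/CLIProxyAPI/internal/auth/codex"
)

func main() {
	ctx := context.Background()
	srv := codex.NewOAuthServer(1455) // must match the registered redirect URI port

	if err := srv.Start(ctx); err != nil {
		panic(err)
	}
	defer func() { _ = srv.Stop(ctx) }()

	// ...open the browser at the authorization URL here...

	result, err := srv.WaitForCallback(5 * time.Minute)
	if err != nil {
		panic(err)
	}
	fmt.Println("authorization code received, state:", result.State)
}
```
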
diff --git a/internal/auth/codex/openai.go b/internal/auth/codex/openai.go
new file mode 100644
index 00000000..d2583d38
--- /dev/null
+++ b/internal/auth/codex/openai.go
@@ -0,0 +1,36 @@
+package codex
+
+// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow
+type PKCECodes struct {
+    // CodeVerifier is the cryptographically random string used to correlate
+    // the authorization request to the token request
+    CodeVerifier string `json:"code_verifier"`
+    // CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
+    CodeChallenge string `json:"code_challenge"`
+}
+
+// CodexTokenData holds OAuth token information from OpenAI
+type CodexTokenData struct {
+    // IDToken is the JWT ID token containing user claims
+    IDToken string `json:"id_token"`
+    // AccessToken is the OAuth2 access token for API access
+    AccessToken string `json:"access_token"`
+    // RefreshToken is used to obtain new access tokens
+    RefreshToken string `json:"refresh_token"`
+    // AccountID is the OpenAI account identifier
+    AccountID string `json:"account_id"`
+    // Email is the OpenAI account email
+    Email string `json:"email"`
+    // Expire is the token expiry timestamp
+    Expire string `json:"expired"`
+}
+
+// CodexAuthBundle aggregates authentication data after OAuth flow completion
+type CodexAuthBundle struct {
+    // APIKey is the OpenAI API key obtained from token exchange
+    APIKey string `json:"api_key"`
+    // TokenData contains the OAuth tokens from the authentication flow
+    TokenData CodexTokenData `json:"token_data"`
+    // LastRefresh is the timestamp of the last token refresh
+    LastRefresh string `json:"last_refresh"`
+}
diff --git a/internal/auth/codex/openai_auth.go b/internal/auth/codex/openai_auth.go
new file mode 100644
index 00000000..81e1e156
--- /dev/null
+++ b/internal/auth/codex/openai_auth.go
@@ -0,0 +1,269 @@
+package codex
+
+import (
+    "context"
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
+    "net/url"
+    "strings"
+    "time"
+
+    "github.com/luispater/CLIProxyAPI/internal/config"
+    "github.com/luispater/CLIProxyAPI/internal/util"
+    log "github.com/sirupsen/logrus"
+)
+
+const (
+    openaiAuthURL  = "https://auth.openai.com/oauth/authorize"
+    openaiTokenURL = "https://auth.openai.com/oauth/token"
+    openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
+    redirectURI    = "http://localhost:1455/auth/callback"
+)
+
+// CodexAuth handles OpenAI OAuth2 authentication flow
+type CodexAuth struct {
+    httpClient *http.Client
+}
+
+// NewCodexAuth creates a new OpenAI authentication service
+func NewCodexAuth(cfg *config.Config) *CodexAuth {
+    return &CodexAuth{
+        httpClient: util.SetProxy(cfg, &http.Client{}),
+    }
+}
+
+// GenerateAuthURL creates the OAuth authorization URL with PKCE
+func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, error) {
+    if pkceCodes == nil {
+        return "", fmt.Errorf("PKCE codes are required")
+    }
+
+    params := url.Values{
+        "client_id":                  {openaiClientID},
+        "response_type":              {"code"},
+        "redirect_uri":               {redirectURI},
+        "scope":                      {"openid email profile offline_access"},
+        "state":                      {state},
+        "code_challenge":             {pkceCodes.CodeChallenge},
+        "code_challenge_method":      {"S256"},
+        "prompt":                     {"login"},
+        "id_token_add_organizations": {"true"},
+        "codex_cli_simplified_flow":  {"true"},
+    }
+
+    authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode())
+    return authURL, nil
+}
+
+// ExchangeCodeForTokens exchanges authorization code for access tokens
+func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
+    if pkceCodes == nil {
+        return nil, fmt.Errorf("PKCE codes are required for token exchange")
+    }
+
+    // Prepare token exchange request
+    data := url.Values{
+        "grant_type":    {"authorization_code"},
+        "client_id":     {openaiClientID},
+        "code":          {code},
+        "redirect_uri":  {redirectURI},
+        "code_verifier": {pkceCodes.CodeVerifier},
+    }
+
+    req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
+    if err != nil {
+        return nil, fmt.Errorf("failed to create token request: %w", err)
+    }
+
+    req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+    req.Header.Set("Accept", "application/json")
+
+    resp, err := o.httpClient.Do(req)
+    if err != nil {
+        return nil, fmt.Errorf("token exchange request failed: %w", err)
+    }
+    defer func() {
+        _ = resp.Body.Close()
+    }()
+
+    body, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return nil, fmt.Errorf("failed to read token response: %w", err)
+    }
+    // log.Debugf("Token response: %s", string(body))
+
+    if resp.StatusCode != http.StatusOK {
+        return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
+    }
+
+    // Parse token response
+    var tokenResp struct {
+        AccessToken  string `json:"access_token"`
+        RefreshToken string `json:"refresh_token"`
+        IDToken      string `json:"id_token"`
+        TokenType    string `json:"token_type"`
+        ExpiresIn    int    `json:"expires_in"`
+    }
+
+    if err = json.Unmarshal(body, &tokenResp); err != nil {
+        return nil, fmt.Errorf("failed to parse token response: %w", err)
+    }
+
+    // Extract account ID from ID token
+    claims, err := ParseJWTToken(tokenResp.IDToken)
+    if err != nil {
+        log.Warnf("Failed to parse ID token: %v", err)
+    }
+
+    accountID := ""
+    email := ""
+    if claims != nil {
+        accountID = claims.GetAccountID()
+        email = claims.GetUserEmail()
+    }
+
+    // Create token data
+    tokenData := CodexTokenData{
+        IDToken:      tokenResp.IDToken,
+        AccessToken:  tokenResp.AccessToken,
+        RefreshToken: tokenResp.RefreshToken,
+        AccountID:    accountID,
+        Email:        email,
+        Expire:       time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
+    }
+
+    // Create auth bundle
+    bundle := &CodexAuthBundle{
+        TokenData:   tokenData,
+        LastRefresh: time.Now().Format(time.RFC3339),
+    }
+
+    return bundle, nil
+}
+
+// RefreshTokens refreshes the access token using the refresh token
+func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*CodexTokenData, error) {
+    if refreshToken == "" {
+        return nil, fmt.Errorf("refresh token is required")
+    }
+
+    data := url.Values{
+        "client_id":     {openaiClientID},
+        "grant_type":    {"refresh_token"},
+        "refresh_token": {refreshToken},
+        "scope":         {"openid profile email"},
+    }
+
+    req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
+    if err != nil {
+        return nil, fmt.Errorf("failed to create refresh request: %w", err)
+    }
+
+    req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+    req.Header.Set("Accept", "application/json")
+
+    resp, err := o.httpClient.Do(req)
+    if err != nil {
+        return nil, fmt.Errorf("token refresh request failed: %w", err)
+    }
+    defer func() {
+        _ = resp.Body.Close()
+    }()
+
+    body, err := io.ReadAll(resp.Body)
+    if err != nil {
+        return nil, fmt.Errorf("failed to read refresh response: %w", err)
+    }
+
+    if resp.StatusCode != http.StatusOK {
+        return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
+    }
+
+    var tokenResp struct {
+        AccessToken  string `json:"access_token"`
+        RefreshToken string `json:"refresh_token"`
+        IDToken      string `json:"id_token"`
+        TokenType    string `json:"token_type"`
+        ExpiresIn    int    `json:"expires_in"`
+    }
+
+    if err = json.Unmarshal(body, &tokenResp); err != nil {
+        return nil, fmt.Errorf("failed to parse refresh response: %w", err)
+    }
+
+    // Extract account ID from ID token
+    claims, err := ParseJWTToken(tokenResp.IDToken)
+    if err != nil {
+        log.Warnf("Failed to parse refreshed ID token: %v", err)
+    }
+
+    accountID := ""
+    email := ""
+    if claims != nil {
+        accountID = claims.GetAccountID()
+        email = claims.Email
+    }
+
+    return &CodexTokenData{
+        IDToken:      tokenResp.IDToken,
+        AccessToken:  tokenResp.AccessToken,
+        RefreshToken: tokenResp.RefreshToken,
+        AccountID:    accountID,
+        Email:        email,
+        Expire:       time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
+    }, nil
+}
+
+// CreateTokenStorage creates a new CodexTokenStorage from auth bundle and user info
+func (o *CodexAuth) CreateTokenStorage(bundle *CodexAuthBundle) *CodexTokenStorage {
+    storage := &CodexTokenStorage{
+        IDToken:      bundle.TokenData.IDToken,
+        AccessToken:  bundle.TokenData.AccessToken,
+        RefreshToken: bundle.TokenData.RefreshToken,
+        AccountID:    bundle.TokenData.AccountID,
+        LastRefresh:  bundle.LastRefresh,
+        Email:        bundle.TokenData.Email,
+        Expire:       bundle.TokenData.Expire,
+    }
+
+    return storage
+}
+
+// RefreshTokensWithRetry refreshes tokens with automatic retry logic
+func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*CodexTokenData, error) {
+    var lastErr error
+
+    for attempt := 0; attempt < maxRetries; attempt++ {
+        if attempt > 0 {
+            // Wait before retry
+            select {
+            case <-ctx.Done():
+                return nil, ctx.Err()
+            case <-time.After(time.Duration(attempt) * time.Second):
+            }
+        }
+
+        tokenData, err := o.RefreshTokens(ctx, refreshToken)
+        if err == nil {
+            return tokenData, nil
+        }
+
+        lastErr = err
+        log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
+    }
+
+    return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
+}
+
+// UpdateTokenStorage updates an existing token storage with new token data
+func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
+    storage.IDToken = tokenData.IDToken
+    storage.AccessToken = tokenData.AccessToken
+    storage.RefreshToken = tokenData.RefreshToken
+    storage.AccountID = tokenData.AccountID
+    storage.LastRefresh = time.Now().Format(time.RFC3339)
+    storage.Email = tokenData.Email
+    storage.Expire = tokenData.Expire
+}
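A sketch of how these pieces compose into a login flow, using GeneratePKCECodes from pkce.go below. The state and code values are placeholders, and passing an empty config to NewCodexAuth is an assumption about how util.SetProxy handles a missing proxy URL:

```go
package main

import (
	"context"
	"fmt"

	"github.com/luispater/CLIProxyAPI/internal/auth/codex"
	"github.com/luispater/CLIProxyAPI/internal/config"
)

func main() {
	ctx := context.Background()
	auth := codex.NewCodexAuth(&config.Config{}) // assumes an empty config means "no proxy"

	pkce, err := codex.GeneratePKCECodes()
	if err != nil {
		panic(err)
	}

	authURL, _ := auth.GenerateAuthURL("random-state", pkce)
	fmt.Println("open in browser:", authURL)

	// After the callback server hands back the authorization code:
	code := "code-from-callback" // placeholder
	bundle, err := auth.ExchangeCodeForTokens(ctx, code, pkce)
	if err != nil {
		panic(err)
	}
	fmt.Println("authenticated as", bundle.TokenData.Email)
}
```
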
diff --git a/internal/auth/codex/pkce.go b/internal/auth/codex/pkce.go
new file mode 100644
index 00000000..a276c6c6
--- /dev/null
+++ b/internal/auth/codex/pkce.go
@@ -0,0 +1,47 @@
+package codex
+
+import (
+    "crypto/rand"
+    "crypto/sha256"
+    "encoding/base64"
+    "fmt"
+)
+
+// GeneratePKCECodes generates a PKCE code verifier and challenge pair
+// following RFC 7636 specifications for OAuth 2.0 PKCE extension
+func GeneratePKCECodes() (*PKCECodes, error) {
+    // Generate code verifier: 43-128 characters, URL-safe
+    codeVerifier, err := generateCodeVerifier()
+    if err != nil {
+        return nil, fmt.Errorf("failed to generate code verifier: %w", err)
+    }
+
+    // Generate code challenge using S256 method
+    codeChallenge := generateCodeChallenge(codeVerifier)
+
+    return &PKCECodes{
+        CodeVerifier:  codeVerifier,
+        CodeChallenge: codeChallenge,
+    }, nil
+}
+
+// generateCodeVerifier creates a cryptographically random string
+// of 128 characters using URL-safe base64 encoding
+func generateCodeVerifier() (string, error) {
+    // Generate 96 random bytes (will result in 128 base64 characters)
+    bytes := make([]byte, 96)
+    _, err := rand.Read(bytes)
+    if err != nil {
+        return "", fmt.Errorf("failed to generate random bytes: %w", err)
+    }
+
+    // Encode to URL-safe base64 without padding
+    return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes), nil
+}
+
+// generateCodeChallenge creates a SHA256 hash of the code verifier
+// and encodes it using URL-safe base64 encoding without padding
+func generateCodeChallenge(codeVerifier string) string {
+    hash := sha256.Sum256([]byte(codeVerifier))
+    return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
+}
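RFC 7636 requires `challenge = BASE64URL(SHA256(verifier))` for the S256 method; a small check that the helpers above satisfy that relation:

```go
package main

import (
	"crypto/sha256"
	"encoding/base64"
	"fmt"

	"github.com/luispater/CLIProxyAPI/internal/auth/codex"
)

func main() {
	pkce, err := codex.GeneratePKCECodes()
	if err != nil {
		panic(err)
	}

	// Recompute the challenge independently and compare.
	sum := sha256.Sum256([]byte(pkce.CodeVerifier))
	want := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(sum[:])

	fmt.Println(len(pkce.CodeVerifier) == 128) // 96 random bytes -> 128 base64url chars
	fmt.Println(pkce.CodeChallenge == want)    // S256 relation holds
}
```
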
`json:"account_id"` + // LastRefresh is the timestamp of the last token refresh + LastRefresh string `json:"last_refresh"` + // Email is the OpenAI account email + Email string `json:"email"` + // Type indicates the type (gemini, chatgpt, claude) of token storage. + Type string `json:"type"` + // Expire is the timestamp of the token expire + Expire string `json:"expired"` +} + +// SaveTokenToFile serializes the token storage to a JSON file. +func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error { + ts.Type = "codex" + if err := os.MkdirAll(path.Dir(authFilePath), 0700); err != nil { + return fmt.Errorf("failed to create directory: %v", err) + } + + f, err := os.Create(authFilePath) + if err != nil { + return fmt.Errorf("failed to create token file: %w", err) + } + defer func() { + _ = f.Close() + }() + + if err = json.NewEncoder(f).Encode(ts); err != nil { + return fmt.Errorf("failed to write token to file: %w", err) + } + return nil + +} diff --git a/internal/auth/auth.go b/internal/auth/gemini/gemini_auth.go similarity index 74% rename from internal/auth/auth.go rename to internal/auth/gemini/gemini_auth.go index 8a67c3c9..c8719452 100644 --- a/internal/auth/auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -1,7 +1,7 @@ // Package auth provides OAuth2 authentication functionality for Google Cloud APIs. // It handles the complete OAuth2 flow including token storage, web-based authentication, // proxy support, and automatic token refresh. The package supports both SOCKS5 and HTTP/HTTPS proxies. -package auth +package gemini import ( "context" @@ -14,9 +14,10 @@ import ( "net/url" "time" + "github.com/luispater/CLIProxyAPI/internal/auth/codex" + "github.com/luispater/CLIProxyAPI/internal/browser" "github.com/luispater/CLIProxyAPI/internal/config" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "github.com/tidwall/gjson" "golang.org/x/net/proxy" @@ -25,22 +26,29 @@ import ( ) const ( - oauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - oauthClientSecret = "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" + geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" + geminiOauthClientSecret = "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" ) var ( - oauthScopes = []string{ + geminiOauthScopes = []string{ "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile", } ) +type GeminiAuth struct { +} + +func NewGeminiAuth() *GeminiAuth { + return &GeminiAuth{} +} + // GetAuthenticatedClient configures and returns an HTTP client ready for making authenticated API calls. // It manages the entire OAuth2 flow, including handling proxies, loading existing tokens, // initiating a new web-based OAuth flow if necessary, and refreshing tokens. -func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) { +func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) { // Configure proxy settings for the HTTP client if a proxy URL is provided. proxyURL, err := url.Parse(cfg.ProxyURL) if err == nil { @@ -72,10 +80,10 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C // Configure the OAuth2 client. 
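// A minimal, runnable sketch of the token-persistence convention above: the storage is
// JSON-encoded under auth-dir with mode-0700 directories, and codex files are named
// codex-<email>.json (the convention CodexClient.SaveTokenToFile uses later in this diff).
// The struct mirrors only a subset of CodexTokenStorage's fields; all values are placeholders.
package main

import (
	"encoding/json"
	"fmt"
	"os"
	"path"
)

type codexTokenFile struct {
	AccessToken string `json:"access_token"`
	Email       string `json:"email"`
	Type        string `json:"type"`
	Expire      string `json:"expired"` // note: the JSON key is "expired", matching the field tag above
}

func main() {
	ts := codexTokenFile{AccessToken: "placeholder", Email: "user@example.com", Type: "codex", Expire: "2025-01-01T00:00:00Z"}
	authFilePath := path.Join(os.TempDir(), "auths", fmt.Sprintf("codex-%s.json", ts.Email))
	if err := os.MkdirAll(path.Dir(authFilePath), 0700); err != nil {
		panic(err)
	}
	f, err := os.Create(authFilePath)
	if err != nil {
		panic(err)
	}
	defer func() { _ = f.Close() }()
	if err = json.NewEncoder(f).Encode(&ts); err != nil {
		panic(err)
	}
	fmt.Println("wrote", authFilePath)
}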
conf := &oauth2.Config{ - ClientID: oauthClientID, - ClientSecret: oauthClientSecret, + ClientID: geminiOauthClientID, + ClientSecret: geminiOauthClientSecret, RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server. - Scopes: oauthScopes, + Scopes: geminiOauthScopes, Endpoint: google.Endpoint, } @@ -84,12 +92,12 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C // If no token is found in storage, initiate the web-based OAuth flow. if ts.Token == nil { log.Info("Could not load token from file, starting OAuth flow.") - token, err = getTokenFromWeb(ctx, conf) + token, err = g.getTokenFromWeb(ctx, conf, noBrowser...) if err != nil { return nil, fmt.Errorf("failed to get token from web: %w", err) } // After getting a new token, create a new token storage object with user info. - newTs, errCreateTokenStorage := createTokenStorage(ctx, conf, token, ts.ProjectID) + newTs, errCreateTokenStorage := g.createTokenStorage(ctx, conf, token, ts.ProjectID) if errCreateTokenStorage != nil { log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage) return nil, errCreateTokenStorage @@ -107,9 +115,9 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C return conf.Client(ctx, token), nil } -// createTokenStorage creates a new TokenStorage object. It fetches the user's email +// createTokenStorage creates a new GeminiTokenStorage object. It fetches the user's email // using the provided token and populates the storage structure. -func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) { +func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*GeminiTokenStorage, error) { httpClient := config.Client(ctx, token) req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if err != nil { @@ -148,12 +156,12 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth } ifToken["token_uri"] = "https://oauth2.googleapis.com/token" - ifToken["client_id"] = oauthClientID - ifToken["client_secret"] = oauthClientSecret - ifToken["scopes"] = oauthScopes + ifToken["client_id"] = geminiOauthClientID + ifToken["client_secret"] = geminiOauthClientSecret + ifToken["scopes"] = geminiOauthScopes ifToken["universe_domain"] = "googleapis.com" - ts := TokenStorage{ + ts := GeminiTokenStorage{ Token: ifToken, ProjectID: projectID, Email: emailResult.String(), @@ -166,7 +174,7 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth // It starts a local HTTP server to listen for the callback from Google's auth server, // opens the user's browser to the authorization URL, and exchanges the received // authorization code for an access token. -func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) { +func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) { // Use a channel to pass the authorization code from the HTTP handler to the main function. codeChan := make(chan string) errChan := make(chan error) @@ -201,27 +209,46 @@ func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, // Open the authorization URL in the user's browser. 
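// A self-contained sketch of the local-callback flow getTokenFromWeb implements above:
// a short-lived HTTP server on port 8085 hands the OAuth "code" query parameter to the
// main goroutine over a channel, with the same five-minute timeout. State validation and
// shutdown error handling are trimmed to the essentials.
package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

func waitForAuthCode(ctx context.Context) (string, error) {
	codeChan := make(chan string, 1)
	mux := http.NewServeMux()
	mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "Authentication successful. You may close this window.")
		codeChan <- r.URL.Query().Get("code")
	})
	server := &http.Server{Addr: ":8085", Handler: mux}
	go func() { _ = server.ListenAndServe() }()
	defer func() { _ = server.Shutdown(ctx) }()

	select {
	case code := <-codeChan:
		return code, nil
	case <-time.After(5 * time.Minute):
		return "", fmt.Errorf("oauth flow timed out")
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

func main() {
	code, err := waitForAuthCode(context.Background()) // blocks until the callback arrives or times out
	fmt.Println(code, err)
}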
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) - log.Debugf("CLI login required.\nAttempting to open authentication page in your browser.\nIf it does not open, please navigate to this URL:\n\n%s\n", authURL) - var err error - err = open.Run(authURL) - if err != nil { - log.Errorf("Failed to open browser: %v. Please open the URL manually.", err) + if len(noBrowser) == 1 && !noBrowser[0] { + log.Info("Opening browser for authentication...") + + // Check if browser is available + if !browser.IsAvailable() { + log.Warn("No browser available on this system") + log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL) + } else { + if err := browser.OpenURL(authURL); err != nil { + authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err) + log.Warn(codex.GetUserFriendlyMessage(authErr)) + log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL) + + // Log platform info for debugging + platformInfo := browser.GetPlatformInfo() + log.Debugf("Browser platform info: %+v", platformInfo) + } else { + log.Debug("Browser opened successfully") + } + } + } else { + log.Infof("Please open this URL in your browser:\n\n%s\n", authURL) } + log.Info("Waiting for authentication callback...") + // Wait for the authorization code or an error. var authCode string select { case code := <-codeChan: authCode = code - case err = <-errChan: + case err := <-errChan: return nil, err case <-time.After(5 * time.Minute): // Timeout return nil, fmt.Errorf("oauth flow timed out") } // Shutdown the server. - if err = server.Shutdown(ctx); err != nil { + if err := server.Shutdown(ctx); err != nil { log.Errorf("Failed to shut down server: %v", err) } diff --git a/internal/auth/gemini/gemini_token.go b/internal/auth/gemini/gemini_token.go new file mode 100644 index 00000000..49712d6e --- /dev/null +++ b/internal/auth/gemini/gemini_token.go @@ -0,0 +1,64 @@ +// Package gemini provides authentication and token management functionality +// for Google's Gemini AI services. It handles OAuth2 token storage, serialization, +// and retrieval for maintaining authenticated sessions with the Gemini API. +package gemini + +import ( + "encoding/json" + "fmt" + "os" + "path" +) + +// GeminiTokenStorage defines the structure for storing OAuth2 token information, +// along with associated user and project details. This data is typically +// serialized to a JSON file for persistence. +type GeminiTokenStorage struct { + // Token holds the raw OAuth2 token data, including access and refresh tokens. + Token any `json:"token"` + + // ProjectID is the Google Cloud Project ID associated with this token. + ProjectID string `json:"project_id"` + + // Email is the email address of the authenticated user. + Email string `json:"email"` + + // Auto indicates if the project ID was automatically selected. + Auto bool `json:"auto"` + + // Checked indicates if the associated Cloud AI API has been verified as enabled. + Checked bool `json:"checked"` + + // Type indicates the type (gemini, chatgpt, claude) of token storage. + Type string `json:"type"` +} + +// SaveTokenToFile serializes the token storage to a JSON file. +// This method creates the necessary directory structure and writes the token +// data in JSON format to the specified file path. It ensures the file is +// properly closed after writing. 
+// +// Parameters: +// - authFilePath: The full path where the token file should be saved +// +// Returns: +// - error: An error if the operation fails, nil otherwise +func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error { + ts.Type = "gemini" + if err := os.MkdirAll(path.Dir(authFilePath), 0700); err != nil { + return fmt.Errorf("failed to create directory: %v", err) + } + + f, err := os.Create(authFilePath) + if err != nil { + return fmt.Errorf("failed to create token file: %w", err) + } + defer func() { + _ = f.Close() + }() + + if err = json.NewEncoder(f).Encode(ts); err != nil { + return fmt.Errorf("failed to write token to file: %w", err) + } + return nil +} diff --git a/internal/auth/models.go b/internal/auth/models.go index 33c03745..16f53f72 100644 --- a/internal/auth/models.go +++ b/internal/auth/models.go @@ -1,17 +1,5 @@ package auth -// TokenStorage defines the structure for storing OAuth2 token information, -// along with associated user and project details. This data is typically -// serialized to a JSON file for persistence. -type TokenStorage struct { - // Token holds the raw OAuth2 token data, including access and refresh tokens. - Token any `json:"token"` - // ProjectID is the Google Cloud Project ID associated with this token. - ProjectID string `json:"project_id"` - // Email is the email address of the authenticated user. - Email string `json:"email"` - // Auto indicates if the project ID was automatically selected. - Auto bool `json:"auto"` - // Checked indicates if the associated Cloud AI API has been verified as enabled. - Checked bool `json:"checked"` +type TokenStorage interface { + SaveTokenToFile(authFilePath string) error } diff --git a/internal/browser/browser.go b/internal/browser/browser.go new file mode 100644 index 00000000..df783cfb --- /dev/null +++ b/internal/browser/browser.go @@ -0,0 +1,121 @@ +package browser + +import ( + "fmt" + "os/exec" + "runtime" + + log "github.com/sirupsen/logrus" + "github.com/skratchdot/open-golang/open" +) + +// OpenURL opens a URL in the default browser +func OpenURL(url string) error { + log.Debugf("Attempting to open URL in browser: %s", url) + + // Try using the open-golang library first + err := open.Run(url) + if err == nil { + log.Debug("Successfully opened URL using open-golang library") + return nil + } + + log.Debugf("open-golang failed: %v, trying platform-specific commands", err) + + // Fallback to platform-specific commands + return openURLPlatformSpecific(url) +} + +// openURLPlatformSpecific opens URL using platform-specific commands +func openURLPlatformSpecific(url string) error { + var cmd *exec.Cmd + + switch runtime.GOOS { + case "darwin": // macOS + cmd = exec.Command("open", url) + case "windows": + cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url) + case "linux": + // Try common Linux browsers in order of preference + browsers := []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"} + for _, browser := range browsers { + if _, err := exec.LookPath(browser); err == nil { + cmd = exec.Command(browser, url) + break + } + } + if cmd == nil { + return fmt.Errorf("no suitable browser found on Linux system") + } + default: + return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + } + + log.Debugf("Running command: %s %v", cmd.Path, cmd.Args[1:]) + err := cmd.Start() + if err != nil { + return fmt.Errorf("failed to start browser command: %w", err) + } + + log.Debug("Successfully opened URL using platform-specific 
command") + return nil +} + +// IsAvailable checks if browser opening functionality is available +func IsAvailable() bool { + // First check if open-golang can work + testErr := open.Run("about:blank") + if testErr == nil { + return true + } + + // Check platform-specific commands + switch runtime.GOOS { + case "darwin": + _, err := exec.LookPath("open") + return err == nil + case "windows": + _, err := exec.LookPath("rundll32") + return err == nil + case "linux": + browsers := []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"} + for _, browser := range browsers { + if _, err := exec.LookPath(browser); err == nil { + return true + } + } + return false + default: + return false + } +} + +// GetPlatformInfo returns information about the current platform's browser support +func GetPlatformInfo() map[string]interface{} { + info := map[string]interface{}{ + "os": runtime.GOOS, + "arch": runtime.GOARCH, + "available": IsAvailable(), + } + + switch runtime.GOOS { + case "darwin": + info["default_command"] = "open" + case "windows": + info["default_command"] = "rundll32" + case "linux": + browsers := []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"} + availableBrowsers := []string{} + for _, browser := range browsers { + if _, err := exec.LookPath(browser); err == nil { + availableBrowsers = append(availableBrowsers, browser) + } + } + info["available_browsers"] = availableBrowsers + if len(availableBrowsers) > 0 { + info["default_command"] = availableBrowsers[0] + } + } + + return info +} diff --git a/internal/client/client.go b/internal/client/client.go index d25dd8de..0bfb6073 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -1,975 +1,87 @@ -// Package client provides HTTP client functionality for interacting with Google Cloud AI APIs. -// It handles OAuth2 authentication, token management, request/response processing, -// streaming communication, quota management, and automatic model fallback. -// The package supports both direct API key authentication and OAuth2 flows. +// Package client defines the interface and base structure for AI API clients. +// It provides a common interface that all supported AI service clients must implement, +// including methods for sending messages, handling streams, and managing authentication. package client import ( - "bufio" - "bytes" "context" - "encoding/json" - "fmt" - "io" "net/http" - "os" - "path/filepath" - "runtime" - "strings" "sync" "time" "github.com/luispater/CLIProxyAPI/internal/auth" "github.com/luispater/CLIProxyAPI/internal/config" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "golang.org/x/oauth2" ) -const ( - codeAssistEndpoint = "https://cloudcode-pa.googleapis.com" - apiVersion = "v1internal" - pluginVersion = "0.1.9" +// Client defines the interface that all AI API clients must implement. +// This interface provides methods for interacting with various AI services +// including sending messages, streaming responses, and managing authentication. +type Client interface { + // GetRequestMutex returns the mutex used to synchronize requests for this client. + // This ensures that only one request is processed at a time for quota management. + GetRequestMutex() *sync.Mutex - glEndPoint = "https://generativelanguage.googleapis.com" - glAPIVersion = "v1beta" -) + // GetUserAgent returns the User-Agent string used for HTTP requests. 
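// A small sketch of the Linux fallback strategy openURLPlatformSpecific and IsAvailable
// use in internal/browser above: probe a fixed list of launcher commands with
// exec.LookPath and start the first one found; if none exists, fall back to printing the
// URL for the user. The candidate list matches the diff; the example URL is a placeholder.
package main

import (
	"fmt"
	"os/exec"
)

func pickLinuxLauncher() (string, bool) {
	for _, cand := range []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"} {
		if _, err := exec.LookPath(cand); err == nil {
			return cand, true
		}
	}
	return "", false
}

func main() {
	if launcher, ok := pickLinuxLauncher(); ok {
		fmt.Println("opening with:", launcher)
		_ = exec.Command(launcher, "https://example.com").Start()
	} else {
		fmt.Println("no suitable browser found; print the URL for the user instead")
	}
}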
+ GetUserAgent() string -var ( - previewModels = map[string][]string{ - "gemini-2.5-pro": {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"}, - "gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"}, - } -) + // SendMessage sends a single message to the AI service and returns the response. + // It takes the raw JSON request, model name, system instructions, conversation contents, + // and tool declarations, then returns the response bytes and any error that occurred. + SendMessage(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) -// Client is the main client for interacting with the CLI API. -type Client struct { - httpClient *http.Client - RequestMutex sync.Mutex - tokenStorage *auth.TokenStorage - cfg *config.Config + // SendMessageStream sends a message to the AI service and returns streaming responses. + // It takes similar parameters to SendMessage but returns channels for streaming data + // and errors, enabling real-time response processing. + SendMessageStream(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) + + // SendRawMessage sends a raw JSON message to the AI service without translation. + // This method is used when the request is already in the service's native format. + SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) + + // SendRawMessageStream sends a raw JSON message and returns streaming responses. + // Similar to SendRawMessage but for streaming responses. + SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) + + // SendRawTokenCount sends a token count request to the AI service. + // This method is used to estimate the number of tokens in a given text. + SendRawTokenCount(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) + + // SaveTokenToFile saves the client's authentication token to a file. + // This is used for persisting authentication state between sessions. + SaveTokenToFile() error + + // IsModelQuotaExceeded checks if the specified model has exceeded its quota. + // This helps with load balancing and automatic failover to alternative models. + IsModelQuotaExceeded(model string) bool + + // GetEmail returns the email associated with the client's authentication. + // This is used for logging and identification purposes. + GetEmail() string +} + +// ClientBase provides a common base structure for all AI API clients. +// It implements shared functionality such as request synchronization, HTTP client management, +// configuration access, token storage, and quota tracking. +type ClientBase struct { + // RequestMutex ensures only one request is processed at a time for quota management. + RequestMutex *sync.Mutex + + // httpClient is the HTTP client used for making API requests. + httpClient *http.Client + + // cfg holds the application configuration. + cfg *config.Config + + // tokenStorage manages authentication tokens for the client. + tokenStorage auth.TokenStorage + + // modelQuotaExceeded tracks when models have exceeded their quota. + // The map key is the model name, and the value is the time when the quota was exceeded. modelQuotaExceeded map[string]*time.Time - glAPIKey string } -// NewClient creates a new CLI API client. 
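// A sketch of consuming the Client interface defined above in a provider-agnostic way:
// both channels returned by SendRawMessageStream are drained until they close, so Gemini
// and Codex implementations are interchangeable behind the same call site. Building a
// concrete client and payload is omitted; the payload bytes are whatever the caller prepared.
package main

import (
	"context"
	"fmt"

	"github.com/luispater/CLIProxyAPI/internal/client"
)

func drainStream(ctx context.Context, c client.Client, payload []byte) error {
	dataChan, errChan := c.SendRawMessageStream(ctx, payload, "")
	for dataChan != nil || errChan != nil {
		select {
		case chunk, ok := <-dataChan:
			if !ok {
				dataChan = nil // channel closed; stop selecting on it
				continue
			}
			fmt.Printf("%s\n", chunk)
		case em, ok := <-errChan:
			if !ok {
				errChan = nil
				continue
			}
			return fmt.Errorf("client %s: status %d: %w", c.GetEmail(), em.StatusCode, em.Error)
		}
	}
	return nil
}

func main() { _ = drainStream } // wiring up a concrete Gemini or Codex client is omitted here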
-func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config, glAPIKey ...string) *Client { - var glKey string - if len(glAPIKey) > 0 { - glKey = glAPIKey[0] - } - return &Client{ - httpClient: httpClient, - tokenStorage: ts, - cfg: cfg, - modelQuotaExceeded: make(map[string]*time.Time), - glAPIKey: glKey, - } -} - -// SetProjectID updates the project ID for the client's token storage. -func (c *Client) SetProjectID(projectID string) { - c.tokenStorage.ProjectID = projectID -} - -// SetIsAuto configures whether the client should operate in automatic mode. -func (c *Client) SetIsAuto(auto bool) { - c.tokenStorage.Auto = auto -} - -// SetIsChecked sets the checked status for the client's token storage. -func (c *Client) SetIsChecked(checked bool) { - c.tokenStorage.Checked = checked -} - -// IsChecked returns whether the client's token storage has been checked. -func (c *Client) IsChecked() bool { - return c.tokenStorage.Checked -} - -// IsAuto returns whether the client is operating in automatic mode. -func (c *Client) IsAuto() bool { - return c.tokenStorage.Auto -} - -// GetEmail returns the email address associated with the client's token storage. -func (c *Client) GetEmail() string { - return c.tokenStorage.Email -} - -// GetProjectID returns the Google Cloud project ID from the client's token storage. -func (c *Client) GetProjectID() string { - if c.tokenStorage != nil { - return c.tokenStorage.ProjectID - } - return "" -} - -// GetGenerativeLanguageAPIKey returns the generative language API key if configured. -func (c *Client) GetGenerativeLanguageAPIKey() string { - return c.glAPIKey -} - -// SetupUser performs the initial user onboarding and setup. -func (c *Client) SetupUser(ctx context.Context, email, projectID string) error { - c.tokenStorage.Email = email - log.Info("Performing user onboarding...") - - // 1. LoadCodeAssist - loadAssistReqBody := map[string]interface{}{ - "metadata": getClientMetadata(), - } - if projectID != "" { - loadAssistReqBody["cloudaicompanionProject"] = projectID - } - - var loadAssistResp map[string]interface{} - err := c.makeAPIRequest(ctx, "loadCodeAssist", "POST", loadAssistReqBody, &loadAssistResp) - if err != nil { - return fmt.Errorf("failed to load code assist: %w", err) - } - - // a, _ := json.Marshal(&loadAssistResp) - // log.Debug(string(a)) - // - // a, _ = json.Marshal(loadAssistReqBody) - // log.Debug(string(a)) - - // 2. OnboardUser - var onboardTierID = "legacy-tier" - if tiers, ok := loadAssistResp["allowedTiers"].([]interface{}); ok { - for _, t := range tiers { - if tier, tierOk := t.(map[string]interface{}); tierOk { - if isDefault, isDefaultOk := tier["isDefault"].(bool); isDefaultOk && isDefault { - if id, idOk := tier["id"].(string); idOk { - onboardTierID = id - break - } - } - } - } - } - - onboardProjectID := projectID - if p, ok := loadAssistResp["cloudaicompanionProject"].(string); ok && p != "" { - onboardProjectID = p - } - - onboardReqBody := map[string]interface{}{ - "tierId": onboardTierID, - "metadata": getClientMetadata(), - } - if onboardProjectID != "" { - onboardReqBody["cloudaicompanionProject"] = onboardProjectID - } else { - return fmt.Errorf("failed to start user onboarding, need define a project id") - } - - for { - var lroResp map[string]interface{} - err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp) - if err != nil { - return fmt.Errorf("failed to start user onboarding: %w", err) - } - // a, _ := json.Marshal(&lroResp) - // log.Debug(string(a)) - - // 3. 
Poll Long-Running Operation (LRO) - done, doneOk := lroResp["done"].(bool) - if doneOk && done { - if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk { - if projectID != "" { - c.tokenStorage.ProjectID = projectID - } else { - c.tokenStorage.ProjectID = project["id"].(string) - } - log.Infof("Onboarding complete. Using Project ID: %s", c.tokenStorage.ProjectID) - return nil - } - } else { - log.Println("Onboarding in progress, waiting 5 seconds...") - time.Sleep(5 * time.Second) - } - } -} - -// makeAPIRequest handles making requests to the CLI API endpoints. -func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, body interface{}, result interface{}) error { - var reqBody io.Reader - if body != nil { - jsonBody, err := json.Marshal(body) - if err != nil { - return fmt.Errorf("failed to marshal request body: %w", err) - } - reqBody = bytes.NewBuffer(jsonBody) - } - - url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint) - if strings.HasPrefix(endpoint, "operations/") { - url = fmt.Sprintf("%s/%s", codeAssistEndpoint, endpoint) - } - - req, err := http.NewRequestWithContext(ctx, method, url, reqBody) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token() - if err != nil { - return fmt.Errorf("failed to get token: %w", err) - } - - // Set headers - metadataStr := getClientMetadataString() - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", getUserAgent()) - req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0") - req.Header.Set("Client-Metadata", metadataStr) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) - - resp, err := c.httpClient.Do(req) - if err != nil { - return fmt.Errorf("failed to execute request: %w", err) - } - defer func() { - if err = resp.Body.Close(); err != nil { - log.Printf("warn: failed to close response body: %v", err) - } - }() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - bodyBytes, _ := io.ReadAll(resp.Body) - return fmt.Errorf("api request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - } - - if result != nil { - if err = json.NewDecoder(resp.Body).Decode(result); err != nil { - return fmt.Errorf("failed to decode response body: %w", err) - } - } - - return nil -} - -// APIRequest handles making requests to the CLI API endpoints. 
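// A sketch of the long-running-operation polling loop SetupUser uses above: call
// onboardUser, and while the LRO is not done, wait five seconds and retry. The
// callOnboard parameter is a hypothetical stand-in for makeAPIRequest("onboardUser", ...).
package main

import (
	"context"
	"fmt"
	"time"
)

func waitForOnboarding(ctx context.Context, callOnboard func(context.Context) (map[string]interface{}, error)) (string, error) {
	for {
		lro, err := callOnboard(ctx)
		if err != nil {
			return "", err
		}
		if done, _ := lro["done"].(bool); done {
			resp, _ := lro["response"].(map[string]interface{})
			project, _ := resp["cloudaicompanionProject"].(map[string]interface{})
			id, _ := project["id"].(string)
			return id, nil
		}
		select {
		case <-ctx.Done():
			return "", ctx.Err()
		case <-time.After(5 * time.Second): // same cadence as the onboarding loop above
		}
	}
}

func main() {
	id, err := waitForOnboarding(context.Background(), func(context.Context) (map[string]interface{}, error) {
		// Hypothetical response shaped like the LRO payload handled above.
		return map[string]interface{}{
			"done":     true,
			"response": map[string]interface{}{"cloudaicompanionProject": map[string]interface{}{"id": "my-project"}},
		}, nil
	})
	fmt.Println(id, err)
}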
-func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, alt string, stream bool) (io.ReadCloser, *ErrorMessage) { - var jsonBody []byte - var err error - if byteBody, ok := body.([]byte); ok { - jsonBody = byteBody - } else { - jsonBody, err = json.Marshal(body) - if err != nil { - return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)} - } - } - - var url string - if c.glAPIKey == "" { - // Add alt=sse for streaming - url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint) - if alt == "" && stream { - url = url + "?alt=sse" - } else { - if alt != "" { - url = url + fmt.Sprintf("?$alt=%s", alt) - } - } - } else { - if endpoint == "countTokens" { - modelResult := gjson.GetBytes(jsonBody, "model") - url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint) - } else { - modelResult := gjson.GetBytes(jsonBody, "model") - url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint) - if alt == "" && stream { - url = url + "?alt=sse" - } else { - if alt != "" { - url = url + fmt.Sprintf("?$alt=%s", alt) - } - } - jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw) - systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction") - if systemInstructionResult.Exists() { - jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw)) - jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction") - jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id") - } - } - } - - // log.Debug(string(jsonBody)) - // log.Debug(url) - reqBody := bytes.NewBuffer(jsonBody) - - req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody) - if err != nil { - return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %v", err)} - } - - // Set headers - metadataStr := getClientMetadataString() - req.Header.Set("Content-Type", "application/json") - if c.glAPIKey == "" { - token, errToken := c.httpClient.Transport.(*oauth2.Transport).Source.Token() - if errToken != nil { - return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %v", errToken)} - } - req.Header.Set("User-Agent", getUserAgent()) - req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0") - req.Header.Set("Client-Metadata", metadataStr) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) - } else { - req.Header.Set("x-goog-api-key", c.glAPIKey) - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %v", err)} - } - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - defer func() { - if err = resp.Body.Close(); err != nil { - log.Printf("warn: failed to close response body: %v", err) - } - }() - bodyBytes, _ := io.ReadAll(resp.Body) - // log.Debug(string(jsonBody)) - return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))} - } - - return resp.Body, nil -} - -// SendMessage handles a single conversational turn, including tool calls. 
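// A condensed sketch of the header split APIRequest applies above: OAuth-based clients
// send a Bearer token plus the CLI metadata headers, while Generative Language API key
// clients send only x-goog-api-key. The endpoint and header values mirror constants in
// this diff; the token and key strings are placeholders.
package main

import (
	"fmt"
	"net/http"
)

func setAuthHeaders(req *http.Request, accessToken, glAPIKey string) {
	req.Header.Set("Content-Type", "application/json")
	if glAPIKey == "" {
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
		req.Header.Set("User-Agent", "google-api-nodejs-client/9.15.1")
		req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0")
	} else {
		req.Header.Set("x-goog-api-key", glAPIKey)
	}
}

func main() {
	req, _ := http.NewRequest("POST", "https://cloudcode-pa.googleapis.com/v1internal:streamGenerateContent?alt=sse", nil)
	setAuthHeaders(req, "ya29.placeholder", "")
	fmt.Println(req.Header)
}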
-func (c *Client) SendMessage(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) { - request := GenerateContentRequest{ - Contents: contents, - GenerationConfig: GenerationConfig{ - ThinkingConfig: GenerationConfigThinkingConfig{ - IncludeThoughts: true, - }, - }, - } - - request.SystemInstruction = systemInstruction - - request.Tools = tools - - requestBody := map[string]interface{}{ - "project": c.GetProjectID(), // Assuming ProjectID is available - "request": request, - "model": model, - } - - byteRequestBody, _ := json.Marshal(requestBody) - - // log.Debug(string(byteRequestBody)) - - reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort") - if reasoningEffortResult.String() == "none" { - byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts") - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0) - } else if reasoningEffortResult.String() == "auto" { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) - } else if reasoningEffortResult.String() == "low" { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024) - } else if reasoningEffortResult.String() == "medium" { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192) - } else if reasoningEffortResult.String() == "high" { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576) - } else { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) - } - - temperatureResult := gjson.GetBytes(rawJSON, "temperature") - if temperatureResult.Exists() && temperatureResult.Type == gjson.Number { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num) - } - - topPResult := gjson.GetBytes(rawJSON, "top_p") - if topPResult.Exists() && topPResult.Type == gjson.Number { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num) - } - - topKResult := gjson.GetBytes(rawJSON, "top_k") - if topKResult.Exists() && topKResult.Type == gjson.Number { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) - } - - modelName := model - // log.Debug(string(byteRequestBody)) - for { - if c.isModelQuotaExceeded(modelName) { - if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { - modelName = c.getPreviewModel(model) - if modelName != "" { - log.Debugf("Model %s is quota exceeded. 
Switch to preview model %s", model, modelName) - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName) - continue - } - } - return nil, &ErrorMessage{ - StatusCode: 429, - Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), - } - } - - respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, "", false) - if err != nil { - if err.StatusCode == 429 { - now := time.Now() - c.modelQuotaExceeded[modelName] = &now - if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { - continue - } - } - return nil, err - } - delete(c.modelQuotaExceeded, modelName) - bodyBytes, errReadAll := io.ReadAll(respBody) - if errReadAll != nil { - return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} - } - return bodyBytes, nil - } -} - -// SendMessageStream handles streaming conversational turns with comprehensive parameter management. -// This function implements a sophisticated streaming system that supports tool calls, reasoning modes, -// quota management, and automatic model fallback. It returns two channels for asynchronous communication: -// one for streaming response data and another for error handling. -func (c *Client) SendMessageStream(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) { - // Define the data prefix used in Server-Sent Events streaming format - dataTag := []byte("data: ") - - // Create channels for asynchronous communication - // errChan: delivers error messages during streaming - // dataChan: delivers response data chunks - errChan := make(chan *ErrorMessage) - dataChan := make(chan []byte) - - // Launch a goroutine to handle the streaming process asynchronously - // This allows the function to return immediately while processing continues in the background - go func() { - // Ensure channels are properly closed when the goroutine exits - defer close(errChan) - defer close(dataChan) - - // Configure thinking/reasoning capabilities - // Default to including thoughts unless explicitly disabled - includeThoughtsFlag := true - if len(includeThoughts) > 0 { - includeThoughtsFlag = includeThoughts[0] - } - - // Build the base request structure for the Gemini API - // This includes conversation contents and generation configuration - request := GenerateContentRequest{ - Contents: contents, - GenerationConfig: GenerationConfig{ - ThinkingConfig: GenerationConfigThinkingConfig{ - IncludeThoughts: includeThoughtsFlag, - }, - }, - } - - // Add system instructions if provided - // System instructions guide the AI's behavior and response style - request.SystemInstruction = systemInstruction - - // Add available tools for function calling capabilities - // Tools allow the AI to perform actions beyond text generation - request.Tools = tools - - // Construct the complete request body with project context - // The project ID is essential for proper API routing and billing - requestBody := map[string]interface{}{ - "project": c.GetProjectID(), // Project ID for API routing and quota management - "request": request, - "model": model, - } - - // Serialize the request body to JSON for API transmission - byteRequestBody, _ := json.Marshal(requestBody) - - // Parse and configure reasoning effort levels from the original request - // This maps Claude-style reasoning effort parameters to Gemini's thinking budget system - reasoningEffortResult := 
gjson.GetBytes(rawJSON, "reasoning_effort") - if reasoningEffortResult.String() == "none" { - // Disable thinking entirely for fastest responses - byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts") - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0) - } else if reasoningEffortResult.String() == "auto" { - // Let the model decide the appropriate thinking budget automatically - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) - } else if reasoningEffortResult.String() == "low" { - // Minimal thinking for simple tasks (1KB thinking budget) - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024) - } else if reasoningEffortResult.String() == "medium" { - // Moderate thinking for complex tasks (8KB thinking budget) - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192) - } else if reasoningEffortResult.String() == "high" { - // Maximum thinking for very complex tasks (24KB thinking budget) - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576) - } else { - // Default to automatic thinking budget if no specific level is provided - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) - } - - // Configure temperature parameter for response randomness control - // Temperature affects the creativity vs consistency trade-off in responses - temperatureResult := gjson.GetBytes(rawJSON, "temperature") - if temperatureResult.Exists() && temperatureResult.Type == gjson.Number { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num) - } - - // Configure top-p parameter for nucleus sampling - // Controls the cumulative probability threshold for token selection - topPResult := gjson.GetBytes(rawJSON, "top_p") - if topPResult.Exists() && topPResult.Type == gjson.Number { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num) - } - - // Configure top-k parameter for limiting token candidates - // Restricts the model to consider only the top K most likely tokens - topKResult := gjson.GetBytes(rawJSON, "top_k") - if topKResult.Exists() && topKResult.Type == gjson.Number { - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) - } - - // Initialize model name for quota management and potential fallback - modelName := model - var stream io.ReadCloser - - // Quota management and model fallback loop - // This loop handles quota exceeded scenarios and automatic model switching - for { - // Check if the current model has exceeded its quota - if c.isModelQuotaExceeded(modelName) { - // Attempt to switch to a preview model if configured and using account auth - if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { - modelName = c.getPreviewModel(model) - if modelName != "" { - log.Debugf("Model %s is quota exceeded. 
Switch to preview model %s", model, modelName) - // Update the request body with the new model name - byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName) - continue // Retry with the preview model - } - } - // If no fallback is available, return a quota exceeded error - errChan <- &ErrorMessage{ - StatusCode: 429, - Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), - } - return - } - - // Attempt to establish a streaming connection with the API - var err *ErrorMessage - stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, "", true) - if err != nil { - // Handle quota exceeded errors by marking the model and potentially retrying - if err.StatusCode == 429 { - now := time.Now() - c.modelQuotaExceeded[modelName] = &now // Mark model as quota exceeded - // If preview model switching is enabled, retry the loop - if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { - continue - } - } - // Forward other errors to the error channel - errChan <- err - return - } - // Clear any previous quota exceeded status for this model - delete(c.modelQuotaExceeded, modelName) - break // Successfully established connection, exit the retry loop - } - - // Process the streaming response using a scanner - // This handles the Server-Sent Events format from the API - scanner := bufio.NewScanner(stream) - for scanner.Scan() { - line := scanner.Bytes() - // Filter and forward only data lines (those prefixed with "data: ") - // This extracts the actual JSON content from the SSE format - if bytes.HasPrefix(line, dataTag) { - dataChan <- line[6:] // Remove "data: " prefix and send the JSON content - } - } - - // Handle any scanning errors that occurred during stream processing - if errScanner := scanner.Err(); errScanner != nil { - // Send a 500 Internal Server Error for scanning failures - errChan <- &ErrorMessage{500, errScanner} - _ = stream.Close() - return - } - - // Ensure the stream is properly closed to prevent resource leaks - _ = stream.Close() - }() - - // Return the channels immediately for asynchronous communication - // The caller can read from these channels while the goroutine processes the request - return dataChan, errChan -} - -// SendRawTokenCount handles a token count. -func (c *Client) SendRawTokenCount(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) { - modelResult := gjson.GetBytes(rawJSON, "model") - model := modelResult.String() - modelName := model - for { - if c.isModelQuotaExceeded(modelName) { - if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { - modelName = c.getPreviewModel(model) - if modelName != "" { - log.Debugf("Model %s is quota exceeded. 
Switch to preview model %s", model, modelName) - rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName) - continue - } - } - return nil, &ErrorMessage{ - StatusCode: 429, - Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), - } - } - - respBody, err := c.APIRequest(ctx, "countTokens", rawJSON, alt, false) - if err != nil { - if err.StatusCode == 429 { - now := time.Now() - c.modelQuotaExceeded[modelName] = &now - if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { - continue - } - } - return nil, err - } - delete(c.modelQuotaExceeded, modelName) - bodyBytes, errReadAll := io.ReadAll(respBody) - if errReadAll != nil { - return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} - } - return bodyBytes, nil - } -} - -// SendRawMessage handles a single conversational turn, including tool calls. -func (c *Client) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) { - if c.glAPIKey == "" { - rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID()) - } - - modelResult := gjson.GetBytes(rawJSON, "model") - model := modelResult.String() - modelName := model - for { - if c.isModelQuotaExceeded(modelName) { - if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { - modelName = c.getPreviewModel(model) - if modelName != "" { - log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) - rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName) - continue - } - } - return nil, &ErrorMessage{ - StatusCode: 429, - Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), - } - } - - respBody, err := c.APIRequest(ctx, "generateContent", rawJSON, alt, false) - if err != nil { - if err.StatusCode == 429 { - now := time.Now() - c.modelQuotaExceeded[modelName] = &now - if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { - continue - } - } - return nil, err - } - delete(c.modelQuotaExceeded, modelName) - bodyBytes, errReadAll := io.ReadAll(respBody) - if errReadAll != nil { - return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} - } - return bodyBytes, nil - } -} - -// SendRawMessageStream handles a single conversational turn, including tool calls. -func (c *Client) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) { - dataTag := []byte("data: ") - errChan := make(chan *ErrorMessage) - dataChan := make(chan []byte) - go func() { - defer close(errChan) - defer close(dataChan) - - if c.glAPIKey == "" { - rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID()) - } - - modelResult := gjson.GetBytes(rawJSON, "model") - model := modelResult.String() - modelName := model - var stream io.ReadCloser - for { - if c.isModelQuotaExceeded(modelName) { - if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { - modelName = c.getPreviewModel(model) - if modelName != "" { - log.Debugf("Model %s is quota exceeded. 
Switch to preview model %s", model, modelName) - rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName) - continue - } - } - errChan <- &ErrorMessage{ - StatusCode: 429, - Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), - } - return - } - var err *ErrorMessage - stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJSON, alt, true) - if err != nil { - if err.StatusCode == 429 { - now := time.Now() - c.modelQuotaExceeded[modelName] = &now - if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { - continue - } - } - errChan <- err - return - } - delete(c.modelQuotaExceeded, modelName) - break - } - - if alt == "" { - scanner := bufio.NewScanner(stream) - for scanner.Scan() { - line := scanner.Bytes() - if bytes.HasPrefix(line, dataTag) { - dataChan <- line[6:] - } - } - - if errScanner := scanner.Err(); errScanner != nil { - errChan <- &ErrorMessage{500, errScanner} - _ = stream.Close() - return - } - - } else { - data, err := io.ReadAll(stream) - if err != nil { - errChan <- &ErrorMessage{500, err} - _ = stream.Close() - return - } - dataChan <- data - } - _ = stream.Close() - - }() - - return dataChan, errChan -} - -// isModelQuotaExceeded checks if the specified model has exceeded its quota -// within the last 30 minutes. -func (c *Client) isModelQuotaExceeded(model string) bool { - if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey { - duration := time.Now().Sub(*lastExceededTime) - if duration > 30*time.Minute { - return false - } - return true - } - return false -} - -// getPreviewModel returns an available preview model for the given base model, -// or an empty string if no preview models are available or all are quota exceeded. -func (c *Client) getPreviewModel(model string) string { - if models, hasKey := previewModels[model]; hasKey { - for i := 0; i < len(models); i++ { - if !c.isModelQuotaExceeded(models[i]) { - return models[i] - } - } - } - return "" -} - -// IsModelQuotaExceeded returns true if the specified model has exceeded its quota -// and no fallback options are available. -func (c *Client) IsModelQuotaExceeded(model string) bool { - if c.isModelQuotaExceeded(model) { - if c.cfg.QuotaExceeded.SwitchPreviewModel { - return c.getPreviewModel(model) == "" - } - return true - } - return false -} - -// CheckCloudAPIIsEnabled sends a simple test request to the API to verify -// that the Cloud AI API is enabled for the user's project. It provides -// an activation URL if the API is disabled. -func (c *Client) CheckCloudAPIIsEnabled() (bool, error) { - ctx, cancel := context.WithCancel(context.Background()) - defer func() { - c.RequestMutex.Unlock() - cancel() - }() - c.RequestMutex.Lock() - - // A simple request to test the API endpoint. - requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID) - - stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), "", true) - if err != nil { - // If a 403 Forbidden error occurs, it likely means the API is not enabled. - if err.StatusCode == 403 { - errJSON := err.Error.Error() - // Check for a specific error code and extract the activation URL. 
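// A sketch of the quota cooldown tracked above: a model is treated as exhausted for 30
// minutes after a 429, and getPreviewModel-style fallback returns the first preview
// variant that is not cooling down. The preview map matches previewModels earlier in this diff.
package main

import (
	"fmt"
	"time"
)

var previews = map[string][]string{
	"gemini-2.5-pro":   {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"},
	"gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"},
}

type quotaTracker struct {
	exceeded map[string]time.Time // model name -> time the 429 was observed
}

func (q *quotaTracker) isExceeded(model string) bool {
	t, ok := q.exceeded[model]
	return ok && time.Since(t) <= 30*time.Minute
}

func (q *quotaTracker) previewFallback(model string) string {
	for _, m := range previews[model] {
		if !q.isExceeded(m) {
			return m
		}
	}
	return "" // every variant is cooling down
}

func main() {
	q := &quotaTracker{exceeded: map[string]time.Time{"gemini-2.5-pro": time.Now()}}
	fmt.Println(q.previewFallback("gemini-2.5-pro"))
}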
- if gjson.Get(errJSON, "0.error.code").Int() == 403 { - activationURL := gjson.Get(errJSON, "0.error.details.0.metadata.activationUrl").String() - if activationURL != "" { - log.Warnf( - "\n\nPlease activate your account with this url:\n\n%s\n\n And execute this command again:\n%s --login --project_id %s", - activationURL, - os.Args[0], - c.tokenStorage.ProjectID, - ) - } - } - log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJSON) - return false, nil - } - return false, err.Error - } - defer func() { - _ = stream.Close() - }() - - // We only need to know if the request was successful, so we can drain the stream. - scanner := bufio.NewScanner(stream) - for scanner.Scan() { - // Do nothing, just consume the stream. - } - - return scanner.Err() == nil, scanner.Err() -} - -// GetProjectList fetches a list of Google Cloud projects accessible by the user. -func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) { - token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token() - if err != nil { - return nil, fmt.Errorf("failed to get token: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil) - if err != nil { - return nil, fmt.Errorf("could not create project list request: %v", err) - } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to execute project list request: %w", err) - } - defer func() { - _ = resp.Body.Close() - }() - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - } - - var project GCPProject - if err = json.NewDecoder(resp.Body).Decode(&project); err != nil { - return nil, fmt.Errorf("failed to unmarshal project list: %w", err) - } - return &project, nil -} - -// SaveTokenToFile serializes the client's current token storage to a JSON file. -// The filename is constructed from the user's email and project ID. -func (c *Client) SaveTokenToFile() error { - if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil { - return fmt.Errorf("failed to create directory: %v", err) - } - - fileName := filepath.Join(c.cfg.AuthDir, fmt.Sprintf("%s-%s.json", c.tokenStorage.Email, c.tokenStorage.ProjectID)) - log.Infof("Saving credentials to %s", fileName) - f, err := os.Create(fileName) - if err != nil { - return fmt.Errorf("failed to create token file: %w", err) - } - defer func() { - _ = f.Close() - }() - - if err = json.NewEncoder(f).Encode(c.tokenStorage); err != nil { - return fmt.Errorf("failed to write token to file: %w", err) - } - return nil -} - -// getClientMetadata returns a map of metadata about the client environment, -// such as IDE type, platform, and plugin version. -func getClientMetadata() map[string]string { - return map[string]string{ - "ideType": "IDE_UNSPECIFIED", - "platform": "PLATFORM_UNSPECIFIED", - "pluginType": "GEMINI", - // "pluginVersion": pluginVersion, - } -} - -// getClientMetadataString returns the client metadata as a single, -// comma-separated string, which is required for the 'Client-Metadata' header. 
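// A sketch of the gjson lookups CheckCloudAPIIsEnabled performs above when a 403 comes
// back: read the error code, then pull the activation URL out of the nested details.
// The JSON literal is a trimmed, hypothetical example of that error shape.
package main

import (
	"fmt"

	"github.com/tidwall/gjson"
)

func main() {
	errJSON := `[{"error":{"code":403,"details":[{"metadata":{"activationUrl":"https://console.developers.google.com/apis"}}]}}]`
	if gjson.Get(errJSON, "0.error.code").Int() == 403 {
		fmt.Println("activate here:", gjson.Get(errJSON, "0.error.details.0.metadata.activationUrl").String())
	}
}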
-func getClientMetadataString() string { - md := getClientMetadata() - parts := make([]string, 0, len(md)) - for k, v := range md { - parts = append(parts, fmt.Sprintf("%s=%s", k, v)) - } - return strings.Join(parts, ",") -} - -// getUserAgent constructs the User-Agent string for HTTP requests. -func getUserAgent() string { - // return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH) - return "google-api-nodejs-client/9.15.1" -} - -// getPlatform determines the operating system and architecture and formats -// it into a string expected by the backend API. -func getPlatform() string { - goOS := runtime.GOOS - arch := runtime.GOARCH - switch goOS { - case "darwin": - return fmt.Sprintf("DARWIN_%s", strings.ToUpper(arch)) - case "linux": - return fmt.Sprintf("LINUX_%s", strings.ToUpper(arch)) - case "windows": - return fmt.Sprintf("WINDOWS_%s", strings.ToUpper(arch)) - default: - return "PLATFORM_UNSPECIFIED" - } +// GetRequestMutex returns the mutex used to synchronize requests for this client. +// This ensures that only one request is processed at a time for quota management. +func (c *ClientBase) GetRequestMutex() *sync.Mutex { + return c.RequestMutex } diff --git a/internal/client/client_models.go b/internal/client/client_models.go new file mode 100644 index 00000000..0b64efab --- /dev/null +++ b/internal/client/client_models.go @@ -0,0 +1,159 @@ +// Package client defines the data structures used across all AI API clients. +// These structures represent the common data models for requests, responses, +// and configuration parameters used when communicating with various AI services. +package client + +import "time" + +// ErrorMessage encapsulates an error with an associated HTTP status code. +// This structure is used to provide detailed error information including +// both the HTTP status and the underlying error. +type ErrorMessage struct { + // StatusCode is the HTTP status code returned by the API. + StatusCode int + + // Error is the underlying error that occurred. + Error error +} + +// GCPProject represents the response structure for a Google Cloud project list request. +// This structure is used when fetching available projects for a Google Cloud account. +type GCPProject struct { + // Projects is a list of Google Cloud projects accessible by the user. + Projects []GCPProjectProjects `json:"projects"` +} + +// GCPProjectLabels defines the labels associated with a GCP project. +// These labels can contain metadata about the project's purpose or configuration. +type GCPProjectLabels struct { + // GenerativeLanguage indicates if the project has generative language APIs enabled. + GenerativeLanguage string `json:"generative-language"` +} + +// GCPProjectProjects contains details about a single Google Cloud project. +// This includes identifying information, metadata, and configuration details. +type GCPProjectProjects struct { + // ProjectNumber is the unique numeric identifier for the project. + ProjectNumber string `json:"projectNumber"` + + // ProjectID is the unique string identifier for the project. + ProjectID string `json:"projectId"` + + // LifecycleState indicates the current state of the project (e.g., "ACTIVE"). + LifecycleState string `json:"lifecycleState"` + + // Name is the human-readable name of the project. + Name string `json:"name"` + + // Labels contains metadata labels associated with the project. + Labels GCPProjectLabels `json:"labels"` + + // CreateTime is the timestamp when the project was created. 
+ CreateTime time.Time `json:"createTime"` +} + +// Content represents a single message in a conversation, with a role and parts. +// This structure models a message exchange between a user and an AI model. +type Content struct { + // Role indicates who sent the message ("user", "model", or "tool"). + Role string `json:"role"` + + // Parts is a collection of content parts that make up the message. + Parts []Part `json:"parts"` +} + +// Part represents a distinct piece of content within a message. +// A part can be text, inline data (like an image), a function call, or a function response. +type Part struct { + // Text contains plain text content. + Text string `json:"text,omitempty"` + + // InlineData contains base64-encoded data with its MIME type (e.g., images). + InlineData *InlineData `json:"inlineData,omitempty"` + + // FunctionCall represents a tool call requested by the model. + FunctionCall *FunctionCall `json:"functionCall,omitempty"` + + // FunctionResponse represents the result of a tool execution. + FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` +} + +// InlineData represents base64-encoded data with its MIME type. +// This is typically used for embedding images or other binary data in requests. +type InlineData struct { + // MimeType specifies the media type of the embedded data (e.g., "image/png"). + MimeType string `json:"mime_type,omitempty"` + + // Data contains the base64-encoded binary data. + Data string `json:"data,omitempty"` +} + +// FunctionCall represents a tool call requested by the model. +// It includes the function name and its arguments that the model wants to execute. +type FunctionCall struct { + // Name is the identifier of the function to be called. + Name string `json:"name"` + + // Args contains the arguments to pass to the function. + Args map[string]interface{} `json:"args"` +} + +// FunctionResponse represents the result of a tool execution. +// This is sent back to the model after a tool call has been processed. +type FunctionResponse struct { + // Name is the identifier of the function that was called. + Name string `json:"name"` + + // Response contains the result data from the function execution. + Response map[string]interface{} `json:"response"` +} + +// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint. +// This structure defines all the parameters needed for generating content from an AI model. +type GenerateContentRequest struct { + // SystemInstruction provides system-level instructions that guide the model's behavior. + SystemInstruction *Content `json:"systemInstruction,omitempty"` + + // Contents is the conversation history between the user and the model. + Contents []Content `json:"contents"` + + // Tools defines the available tools/functions that the model can call. + Tools []ToolDeclaration `json:"tools,omitempty"` + + // GenerationConfig contains parameters that control the model's generation behavior. + GenerationConfig `json:"generationConfig"` +} + +// GenerationConfig defines parameters that control the model's generation behavior. +// These parameters affect the creativity, randomness, and reasoning of the model's responses. +type GenerationConfig struct { + // ThinkingConfig specifies configuration for the model's "thinking" process. + ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"` + + // Temperature controls the randomness of the model's responses. 
+ // Values closer to 0 make responses more deterministic, while values closer to 1 increase randomness. + Temperature float64 `json:"temperature,omitempty"` + + // TopP controls nucleus sampling, which affects the diversity of responses. + // It limits the model to consider only the top P% of probability mass. + TopP float64 `json:"topP,omitempty"` + + // TopK limits the model to consider only the top K most likely tokens. + // This can help control the quality and diversity of generated text. + TopK float64 `json:"topK,omitempty"` +} + +// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process. +// This controls whether the model should output its reasoning process along with the final answer. +type GenerationConfigThinkingConfig struct { + // IncludeThoughts determines whether the model should output its reasoning process. + // When enabled, the model will include its step-by-step thinking in the response. + IncludeThoughts bool `json:"include_thoughts,omitempty"` +} + +// ToolDeclaration defines the structure for declaring tools (like functions) +// that the model can call during content generation. +type ToolDeclaration struct { + // FunctionDeclarations is a list of available functions that the model can call. + FunctionDeclarations []interface{} `json:"functionDeclarations"` +} diff --git a/internal/client/codex_client.go b/internal/client/codex_client.go new file mode 100644 index 00000000..22d4ddaf --- /dev/null +++ b/internal/client/codex_client.go @@ -0,0 +1,258 @@ +package client + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "path/filepath" + "sync" + "time" + + "github.com/google/uuid" + "github.com/luispater/CLIProxyAPI/internal/auth" + "github.com/luispater/CLIProxyAPI/internal/auth/codex" + "github.com/luispater/CLIProxyAPI/internal/config" + "github.com/luispater/CLIProxyAPI/internal/util" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" +) + +const ( + chatGPTEndpoint = "https://chatgpt.com/backend-api" +) + +// CodexClient implements the Client interface for OpenAI API +type CodexClient struct { + ClientBase + codexAuth *codex.CodexAuth +} + +// NewCodexClient creates a new OpenAI client instance +func NewCodexClient(cfg *config.Config, ts *codex.CodexTokenStorage) (*CodexClient, error) { + httpClient := util.SetProxy(cfg, &http.Client{}) + client := &CodexClient{ + ClientBase: ClientBase{ + RequestMutex: &sync.Mutex{}, + httpClient: httpClient, + cfg: cfg, + modelQuotaExceeded: make(map[string]*time.Time), + tokenStorage: ts, + }, + codexAuth: codex.NewCodexAuth(cfg), + } + + return client, nil +} + +// GetUserAgent returns the user agent string for OpenAI API requests +func (c *CodexClient) GetUserAgent() string { + return "codex-cli" +} + +func (c *CodexClient) TokenStorage() auth.TokenStorage { + return c.tokenStorage +} + +// SendMessage sends a message to OpenAI API (non-streaming) +func (c *CodexClient) SendMessage(_ context.Context, _ []byte, _ string, _ *Content, _ []Content, _ []ToolDeclaration) ([]byte, *ErrorMessage) { + // For now, return an error as OpenAI integration is not fully implemented + return nil, &ErrorMessage{ + StatusCode: http.StatusNotImplemented, + Error: fmt.Errorf("codex message sending not yet implemented"), + } +} + +// SendMessageStream sends a streaming message to OpenAI API +func (c *CodexClient) SendMessageStream(_ context.Context, _ []byte, _ string, _ *Content, _ []Content, _ []ToolDeclaration, _ ...bool) (<-chan []byte, <-chan 
*ErrorMessage) { + errChan := make(chan *ErrorMessage, 1) + errChan <- &ErrorMessage{ + StatusCode: http.StatusNotImplemented, + Error: fmt.Errorf("codex streaming not yet implemented"), + } + close(errChan) + + return nil, errChan +} + +// SendRawMessage sends a raw message to OpenAI API +func (c *CodexClient) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) { + modelResult := gjson.GetBytes(rawJSON, "model") + model := modelResult.String() + modelName := model + + respBody, err := c.APIRequest(ctx, "/codex/responses", rawJSON, alt, false) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + } + return nil, err + } + delete(c.modelQuotaExceeded, modelName) + bodyBytes, errReadAll := io.ReadAll(respBody) + if errReadAll != nil { + return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} + } + return bodyBytes, nil + +} + +// SendRawMessageStream sends a raw streaming message to OpenAI API +func (c *CodexClient) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) { + errChan := make(chan *ErrorMessage) + dataChan := make(chan []byte) + go func() { + defer close(errChan) + defer close(dataChan) + + modelResult := gjson.GetBytes(rawJSON, "model") + model := modelResult.String() + modelName := model + var stream io.ReadCloser + for { + var err *ErrorMessage + stream, err = c.APIRequest(ctx, "/codex/responses", rawJSON, alt, true) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + } + errChan <- err + return + } + delete(c.modelQuotaExceeded, modelName) + break + } + + scanner := bufio.NewScanner(stream) + buffer := make([]byte, 10240*1024) + scanner.Buffer(buffer, 10240*1024) + for scanner.Scan() { + line := scanner.Bytes() + dataChan <- line + } + + if errScanner := scanner.Err(); errScanner != nil { + errChan <- &ErrorMessage{500, errScanner} + _ = stream.Close() + return + } + + _ = stream.Close() + }() + + return dataChan, errChan +} + +// SendRawTokenCount sends a token count request to OpenAI API +func (c *CodexClient) SendRawTokenCount(_ context.Context, _ []byte, _ string) ([]byte, *ErrorMessage) { + return nil, &ErrorMessage{ + StatusCode: http.StatusNotImplemented, + Error: fmt.Errorf("codex token counting not yet implemented"), + } +} + +// SaveTokenToFile persists the token storage to disk +func (c *CodexClient) SaveTokenToFile() error { + fileName := filepath.Join(c.cfg.AuthDir, fmt.Sprintf("codex-%s.json", c.tokenStorage.(*codex.CodexTokenStorage).Email)) + return c.tokenStorage.SaveTokenToFile(fileName) +} + +// RefreshTokens refreshes the access tokens if needed +func (c *CodexClient) RefreshTokens(ctx context.Context) error { + if c.tokenStorage == nil || c.tokenStorage.(*codex.CodexTokenStorage).RefreshToken == "" { + return fmt.Errorf("no refresh token available") + } + + // Refresh tokens using the auth service + newTokenData, err := c.codexAuth.RefreshTokensWithRetry(ctx, c.tokenStorage.(*codex.CodexTokenStorage).RefreshToken, 3) + if err != nil { + return fmt.Errorf("failed to refresh tokens: %w", err) + } + + // Update token storage + c.codexAuth.UpdateTokenStorage(c.tokenStorage.(*codex.CodexTokenStorage), newTokenData) + + // Save updated tokens + if err = c.SaveTokenToFile(); err != nil { + log.Warnf("Failed to save refreshed tokens: %v", err) + } + + log.Debug("codex tokens refreshed successfully") + return nil +} + +// APIRequest handles making requests 
to the ChatGPT backend endpoints.
+func (c *CodexClient) APIRequest(ctx context.Context, endpoint string, body interface{}, _ string, _ bool) (io.ReadCloser, *ErrorMessage) {
+	var jsonBody []byte
+	var err error
+	if byteBody, ok := body.([]byte); ok {
+		jsonBody = byteBody
+	} else {
+		jsonBody, err = json.Marshal(body)
+		if err != nil {
+			return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)}
+		}
+	}
+
+	// The endpoint already carries a leading slash (e.g. "/codex/responses").
+	url := fmt.Sprintf("%s%s", chatGPTEndpoint, endpoint)
+
+	// log.Debug(string(jsonBody))
+	// log.Debug(url)
+	reqBody := bytes.NewBuffer(jsonBody)
+
+	req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
+	if err != nil {
+		return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %v", err)}
+	}
+
+	sessionID := uuid.New().String()
+	// Set headers
+	req.Header.Set("Version", "0.21.0")
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Openai-Beta", "responses=experimental")
+	req.Header.Set("Session_id", sessionID)
+	req.Header.Set("Accept", "text/event-stream")
+	req.Header.Set("Chatgpt-Account-Id", c.tokenStorage.(*codex.CodexTokenStorage).AccountID)
+	req.Header.Set("Originator", "codex_cli_rs")
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.tokenStorage.(*codex.CodexTokenStorage).AccessToken))
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %v", err)}
+	}
+
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		defer func() {
+			if err = resp.Body.Close(); err != nil {
+				log.Printf("warn: failed to close response body: %v", err)
+			}
+		}()
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		// log.Debug(string(jsonBody))
+		return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf("%s", bodyBytes)}
+	}
+
+	return resp.Body, nil
+}
+
+func (c *CodexClient) GetEmail() string {
+	return c.tokenStorage.(*codex.CodexTokenStorage).Email
+}
+
+// IsModelQuotaExceeded returns true if the specified model has exceeded its quota
+// and no fallback options are available.
+func (c *CodexClient) IsModelQuotaExceeded(model string) bool {
+	if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
+		duration := time.Now().Sub(*lastExceededTime)
+		if duration > 30*time.Minute {
+			return false
+		}
+		return true
+	}
+	return false
+}
diff --git a/internal/client/gemini_client.go b/internal/client/gemini_client.go
new file mode 100644
index 00000000..2d2cbb63
--- /dev/null
+++ b/internal/client/gemini_client.go
@@ -0,0 +1,942 @@
+// Package client provides HTTP client functionality for interacting with Google Cloud AI APIs.
+// It handles OAuth2 authentication, token management, request/response processing,
+// streaming communication, quota management, and automatic model fallback.
+// The package supports both direct API key authentication and OAuth2 flows.
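+//
+// An illustrative flow (a sketch only; ctx, ts, cfg, and rawJSON are assumed to
+// be prepared by the caller, and the identifiers come from this package and
+// internal/auth/gemini):
+//
+//	httpClient, _ := gemini.NewGeminiAuth().GetAuthenticatedClient(ctx, &ts, cfg)
+//	cli := NewGeminiClient(httpClient, &ts, cfg)
+//	body, errMsg := cli.SendRawMessage(ctx, rawJSON, "")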
+package client + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + geminiAuth "github.com/luispater/CLIProxyAPI/internal/auth/gemini" + "github.com/luispater/CLIProxyAPI/internal/config" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "golang.org/x/oauth2" +) + +const ( + codeAssistEndpoint = "https://cloudcode-pa.googleapis.com" + apiVersion = "v1internal" + + glEndPoint = "https://generativelanguage.googleapis.com" + glAPIVersion = "v1beta" +) + +var ( + previewModels = map[string][]string{ + "gemini-2.5-pro": {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"}, + "gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"}, + } +) + +// GeminiClient is the main client for interacting with the CLI API. +type GeminiClient struct { + ClientBase + glAPIKey string +} + +// NewGeminiClient creates a new CLI API client. +func NewGeminiClient(httpClient *http.Client, ts *geminiAuth.GeminiTokenStorage, cfg *config.Config, glAPIKey ...string) *GeminiClient { + var glKey string + if len(glAPIKey) > 0 { + glKey = glAPIKey[0] + } + return &GeminiClient{ + ClientBase: ClientBase{ + RequestMutex: &sync.Mutex{}, + httpClient: httpClient, + cfg: cfg, + tokenStorage: ts, + modelQuotaExceeded: make(map[string]*time.Time), + }, + glAPIKey: glKey, + } +} + +// SetProjectID updates the project ID for the client's token storage. +func (c *GeminiClient) SetProjectID(projectID string) { + c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID = projectID +} + +// SetIsAuto configures whether the client should operate in automatic mode. +func (c *GeminiClient) SetIsAuto(auto bool) { + c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Auto = auto +} + +// SetIsChecked sets the checked status for the client's token storage. +func (c *GeminiClient) SetIsChecked(checked bool) { + c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Checked = checked +} + +// IsChecked returns whether the client's token storage has been checked. +func (c *GeminiClient) IsChecked() bool { + return c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Checked +} + +// IsAuto returns whether the client is operating in automatic mode. +func (c *GeminiClient) IsAuto() bool { + return c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Auto +} + +// GetEmail returns the email address associated with the client's token storage. +func (c *GeminiClient) GetEmail() string { + return c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Email +} + +// GetProjectID returns the Google Cloud project ID from the client's token storage. +func (c *GeminiClient) GetProjectID() string { + if c.glAPIKey == "" && c.tokenStorage != nil { + if ts, ok := c.tokenStorage.(*geminiAuth.GeminiTokenStorage); ok { + return ts.ProjectID + } + } + return "" +} + +// GetGenerativeLanguageAPIKey returns the generative language API key if configured. +func (c *GeminiClient) GetGenerativeLanguageAPIKey() string { + return c.glAPIKey +} + +// SetupUser performs the initial user onboarding and setup. +func (c *GeminiClient) SetupUser(ctx context.Context, email, projectID string) error { + c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Email = email + log.Info("Performing user onboarding...") + + // 1. 
LoadCodeAssist
+	loadAssistReqBody := map[string]interface{}{
+		"metadata": c.getClientMetadata(),
+	}
+	if projectID != "" {
+		loadAssistReqBody["cloudaicompanionProject"] = projectID
+	}
+
+	var loadAssistResp map[string]interface{}
+	err := c.makeAPIRequest(ctx, "loadCodeAssist", "POST", loadAssistReqBody, &loadAssistResp)
+	if err != nil {
+		return fmt.Errorf("failed to load code assist: %w", err)
+	}
+
+	// a, _ := json.Marshal(&loadAssistResp)
+	// log.Debug(string(a))
+	//
+	// a, _ = json.Marshal(loadAssistReqBody)
+	// log.Debug(string(a))
+
+	// 2. OnboardUser
+	onboardTierID := "legacy-tier"
+	if tiers, ok := loadAssistResp["allowedTiers"].([]interface{}); ok {
+		for _, t := range tiers {
+			if tier, tierOk := t.(map[string]interface{}); tierOk {
+				if isDefault, isDefaultOk := tier["isDefault"].(bool); isDefaultOk && isDefault {
+					if id, idOk := tier["id"].(string); idOk {
+						onboardTierID = id
+						break
+					}
+				}
+			}
+		}
+	}
+
+	onboardProjectID := projectID
+	if p, ok := loadAssistResp["cloudaicompanionProject"].(string); ok && p != "" {
+		onboardProjectID = p
+	}
+
+	onboardReqBody := map[string]interface{}{
+		"tierId":   onboardTierID,
+		"metadata": c.getClientMetadata(),
+	}
+	if onboardProjectID != "" {
+		onboardReqBody["cloudaicompanionProject"] = onboardProjectID
+	} else {
+		return fmt.Errorf("failed to start user onboarding: a project ID must be specified")
+	}
+
+	for {
+		var lroResp map[string]interface{}
+		err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp)
+		if err != nil {
+			return fmt.Errorf("failed to start user onboarding: %w", err)
+		}
+		// a, _ := json.Marshal(&lroResp)
+		// log.Debug(string(a))
+
+		// 3. Poll Long-Running Operation (LRO)
+		done, doneOk := lroResp["done"].(bool)
+		if doneOk && done {
+			if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk {
+				if projectID != "" {
+					c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID = projectID
+				} else {
+					c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID = project["id"].(string)
+				}
+				log.Infof("Onboarding complete. Using Project ID: %s", c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID)
+				return nil
+			}
+		} else {
+			log.Println("Onboarding in progress, waiting 5 seconds...")
+			time.Sleep(5 * time.Second)
+		}
+	}
+}
+
+// makeAPIRequest handles making requests to the CLI API endpoints.
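+// Endpoint names such as "loadCodeAssist" or "onboardUser" resolve to
+// "{codeAssistEndpoint}/{apiVersion}:{endpoint}", while endpoints beginning
+// with "operations/" are appended to the base URL unchanged.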
+func (c *GeminiClient) makeAPIRequest(ctx context.Context, endpoint, method string, body interface{}, result interface{}) error { + var reqBody io.Reader + if body != nil { + jsonBody, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("failed to marshal request body: %w", err) + } + reqBody = bytes.NewBuffer(jsonBody) + } + + url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint) + if strings.HasPrefix(endpoint, "operations/") { + url = fmt.Sprintf("%s/%s", codeAssistEndpoint, endpoint) + } + + req, err := http.NewRequestWithContext(ctx, method, url, reqBody) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token() + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + + // Set headers + metadataStr := c.getClientMetadataString() + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", c.GetUserAgent()) + req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0") + req.Header.Set("Client-Metadata", metadataStr) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to execute request: %w", err) + } + defer func() { + if err = resp.Body.Close(); err != nil { + log.Printf("warn: failed to close response body: %v", err) + } + }() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + bodyBytes, _ := io.ReadAll(resp.Body) + return fmt.Errorf("api request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + } + + if result != nil { + if err = json.NewDecoder(resp.Body).Decode(result); err != nil { + return fmt.Errorf("failed to decode response body: %w", err) + } + } + + return nil +} + +// APIRequest handles making requests to the CLI API endpoints. 
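+// With OAuth credentials the request goes to the Code Assist endpoint; with a
+// Generative Language API key it goes to generativelanguage.googleapis.com and
+// the inner "request" payload is sent directly. When stream is true and alt is
+// empty, "?alt=sse" is appended so the backend emits Server-Sent Events.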
+func (c *GeminiClient) APIRequest(ctx context.Context, endpoint string, body interface{}, alt string, stream bool) (io.ReadCloser, *ErrorMessage) {
+	var jsonBody []byte
+	var err error
+	if byteBody, ok := body.([]byte); ok {
+		jsonBody = byteBody
+	} else {
+		jsonBody, err = json.Marshal(body)
+		if err != nil {
+			return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)}
+		}
+	}
+
+	var url string
+	if c.glAPIKey == "" {
+		// Add alt=sse for streaming
+		url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
+		if alt == "" && stream {
+			url = url + "?alt=sse"
+		} else if alt != "" {
+			url = url + fmt.Sprintf("?$alt=%s", alt)
+		}
+	} else {
+		modelResult := gjson.GetBytes(jsonBody, "model")
+		url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint)
+		// countTokens requests are sent as-is; other endpoints get the streaming
+		// query parameters and have the inner "request" payload hoisted out.
+		if endpoint != "countTokens" {
+			if alt == "" && stream {
+				url = url + "?alt=sse"
+			} else if alt != "" {
+				url = url + fmt.Sprintf("?$alt=%s", alt)
+			}
+			jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
+			systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
+			if systemInstructionResult.Exists() {
+				jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw))
+				jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction")
+				jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
+			}
+		}
+	}
+
+	// log.Debug(string(jsonBody))
+	// log.Debug(url)
+	reqBody := bytes.NewBuffer(jsonBody)
+
+	req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
+	if err != nil {
+		return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %v", err)}
+	}
+
+	// Set headers
+	metadataStr := c.getClientMetadataString()
+	req.Header.Set("Content-Type", "application/json")
+	if c.glAPIKey == "" {
+		token, errToken := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
+		if errToken != nil {
+			return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %v", errToken)}
+		}
+		req.Header.Set("User-Agent", c.GetUserAgent())
+		req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0")
+		req.Header.Set("Client-Metadata", metadataStr)
+		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
+	} else {
+		req.Header.Set("x-goog-api-key", c.glAPIKey)
+	}
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %v", err)}
+	}
+
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		defer func() {
+			if err = resp.Body.Close(); err != nil {
+				log.Printf("warn: failed to close response body: %v", err)
+			}
+		}()
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		// log.Debug(string(jsonBody))
+		return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf("%s", bodyBytes)}
+	}
+
+	return resp.Body, nil
+}
+
+// SendMessage handles a single conversational turn, including tool calls.
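+// An illustrative call (a sketch only; ctx, rawJSON, and cli are assumed to be
+// prepared by the caller, using the types declared in client_models.go):
+//
+//	contents := []Content{{Role: "user", Parts: []Part{{Text: "Hello"}}}}
+//	resp, errMsg := cli.SendMessage(ctx, rawJSON, "gemini-2.5-pro", nil, contents, nil)
+//	if errMsg != nil {
+//		log.Error(errMsg.Error)
+//	}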
+func (c *GeminiClient) SendMessage(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) { + request := GenerateContentRequest{ + Contents: contents, + GenerationConfig: GenerationConfig{ + ThinkingConfig: GenerationConfigThinkingConfig{ + IncludeThoughts: true, + }, + }, + } + + request.SystemInstruction = systemInstruction + + request.Tools = tools + + requestBody := map[string]interface{}{ + "project": c.GetProjectID(), // Assuming ProjectID is available + "request": request, + "model": model, + } + + byteRequestBody, _ := json.Marshal(requestBody) + + // log.Debug(string(byteRequestBody)) + + reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort") + if reasoningEffortResult.String() == "none" { + byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts") + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0) + } else if reasoningEffortResult.String() == "auto" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) + } else if reasoningEffortResult.String() == "low" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024) + } else if reasoningEffortResult.String() == "medium" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192) + } else if reasoningEffortResult.String() == "high" { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576) + } else { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) + } + + temperatureResult := gjson.GetBytes(rawJSON, "temperature") + if temperatureResult.Exists() && temperatureResult.Type == gjson.Number { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num) + } + + topPResult := gjson.GetBytes(rawJSON, "top_p") + if topPResult.Exists() && topPResult.Type == gjson.Number { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num) + } + + topKResult := gjson.GetBytes(rawJSON, "top_k") + if topKResult.Exists() && topKResult.Type == gjson.Number { + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) + } + + modelName := model + // log.Debug(string(byteRequestBody)) + for { + if c.isModelQuotaExceeded(modelName) { + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + modelName = c.getPreviewModel(model) + if modelName != "" { + log.Debugf("Model %s is quota exceeded. 
Switch to preview model %s", model, modelName)
+					byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
+					continue
+				}
+			}
+			return nil, &ErrorMessage{
+				StatusCode: 429,
+				Error:      fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
+			}
+		}
+
+		respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, "", false)
+		if err != nil {
+			if err.StatusCode == 429 {
+				now := time.Now()
+				c.modelQuotaExceeded[modelName] = &now
+				if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
+					continue
+				}
+			}
+			return nil, err
+		}
+		delete(c.modelQuotaExceeded, modelName)
+		bodyBytes, errReadAll := io.ReadAll(respBody)
+		if errReadAll != nil {
+			return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
+		}
+		return bodyBytes, nil
+	}
+}
+
+// SendMessageStream handles streaming conversational turns with comprehensive parameter management.
+// This function implements a sophisticated streaming system that supports tool calls, reasoning modes,
+// quota management, and automatic model fallback. It returns two channels for asynchronous communication:
+// one for streaming response data and another for error handling.
+func (c *GeminiClient) SendMessageStream(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) {
+	// Define the data prefix used in Server-Sent Events streaming format
+	dataTag := []byte("data: ")
+
+	// Create channels for asynchronous communication
+	// errChan: delivers error messages during streaming
+	// dataChan: delivers response data chunks
+	errChan := make(chan *ErrorMessage)
+	dataChan := make(chan []byte)
+
+	// Launch a goroutine to handle the streaming process asynchronously
+	// This allows the function to return immediately while processing continues in the background
+	go func() {
+		// Ensure channels are properly closed when the goroutine exits
+		defer close(errChan)
+		defer close(dataChan)
+
+		// Configure thinking/reasoning capabilities
+		// Default to including thoughts unless explicitly disabled
+		includeThoughtsFlag := true
+		if len(includeThoughts) > 0 {
+			includeThoughtsFlag = includeThoughts[0]
+		}
+
+		// Build the base request structure for the Gemini API
+		// This includes conversation contents and generation configuration
+		request := GenerateContentRequest{
+			Contents: contents,
+			GenerationConfig: GenerationConfig{
+				ThinkingConfig: GenerationConfigThinkingConfig{
+					IncludeThoughts: includeThoughtsFlag,
+				},
+			},
+		}
+
+		// Add system instructions if provided
+		// System instructions guide the AI's behavior and response style
+		request.SystemInstruction = systemInstruction
+
+		// Add available tools for function calling capabilities
+		// Tools allow the AI to perform actions beyond text generation
+		request.Tools = tools
+
+		// Construct the complete request body with project context
+		// The project ID is essential for proper API routing and billing
+		requestBody := map[string]interface{}{
+			"project": c.GetProjectID(), // Project ID for API routing and quota management
+			"request": request,
+			"model":   model,
+		}
+
+		// Serialize the request body to JSON for API transmission
+		byteRequestBody, _ := json.Marshal(requestBody)
+
+		// Parse and configure reasoning effort levels from the original request
+		// This maps OpenAI-style reasoning_effort values to Gemini's thinking budget system
+		reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort")
+		if reasoningEffortResult.String() == "none" {
+			// Disable thinking entirely for fastest responses
+			byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
+			byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
+		} else if reasoningEffortResult.String() == "auto" {
+			// Let the model decide the appropriate thinking budget automatically
+			byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
+		} else if reasoningEffortResult.String() == "low" {
+			// Minimal thinking for simple tasks (1024-token thinking budget)
+			byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
+		} else if reasoningEffortResult.String() == "medium" {
+			// Moderate thinking for complex tasks (8192-token thinking budget)
+			byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
+		} else if reasoningEffortResult.String() == "high" {
+			// Maximum thinking for very complex tasks (24576-token thinking budget)
+			byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
+		} else {
+			// Default to automatic thinking budget if no specific level is provided
+			byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
+		}
+
+		// Configure temperature parameter for response randomness control
+		// Temperature affects the creativity vs consistency trade-off in responses
+		temperatureResult := gjson.GetBytes(rawJSON, "temperature")
+		if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
+			byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
+		}
+
+		// Configure top-p parameter for nucleus sampling
+		// Controls the cumulative probability threshold for token selection
+		topPResult := gjson.GetBytes(rawJSON, "top_p")
+		if topPResult.Exists() && topPResult.Type == gjson.Number {
+			byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
+		}
+
+		// Configure top-k parameter for limiting token candidates
+		// Restricts the model to consider only the top K most likely tokens
+		topKResult := gjson.GetBytes(rawJSON, "top_k")
+		if topKResult.Exists() && topKResult.Type == gjson.Number {
+			byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
+		}
+
+		// Initialize model name for quota management and potential fallback
+		modelName := model
+		var stream io.ReadCloser
+
+		// Quota management and model fallback loop
+		// This loop handles quota exceeded scenarios and automatic model switching
+		for {
+			// Check if the current model has exceeded its quota
+			if c.isModelQuotaExceeded(modelName) {
+				// Attempt to switch to a preview model if configured and using account auth
+				if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
+					modelName = c.getPreviewModel(model)
+					if modelName != "" {
+						log.Debugf("Model %s is quota exceeded. 
Switch to preview model %s", model, modelName) + // Update the request body with the new model name + byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName) + continue // Retry with the preview model + } + } + // If no fallback is available, return a quota exceeded error + errChan <- &ErrorMessage{ + StatusCode: 429, + Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), + } + return + } + + // Attempt to establish a streaming connection with the API + var err *ErrorMessage + stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, "", true) + if err != nil { + // Handle quota exceeded errors by marking the model and potentially retrying + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now // Mark model as quota exceeded + // If preview model switching is enabled, retry the loop + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + continue + } + } + // Forward other errors to the error channel + errChan <- err + return + } + // Clear any previous quota exceeded status for this model + delete(c.modelQuotaExceeded, modelName) + break // Successfully established connection, exit the retry loop + } + + // Process the streaming response using a scanner + // This handles the Server-Sent Events format from the API + scanner := bufio.NewScanner(stream) + for scanner.Scan() { + line := scanner.Bytes() + // Filter and forward only data lines (those prefixed with "data: ") + // This extracts the actual JSON content from the SSE format + if bytes.HasPrefix(line, dataTag) { + dataChan <- line[6:] // Remove "data: " prefix and send the JSON content + } + } + + // Handle any scanning errors that occurred during stream processing + if errScanner := scanner.Err(); errScanner != nil { + // Send a 500 Internal Server Error for scanning failures + errChan <- &ErrorMessage{500, errScanner} + _ = stream.Close() + return + } + + // Ensure the stream is properly closed to prevent resource leaks + _ = stream.Close() + }() + + // Return the channels immediately for asynchronous communication + // The caller can read from these channels while the goroutine processes the request + return dataChan, errChan +} + +// SendRawTokenCount handles a token count. +func (c *GeminiClient) SendRawTokenCount(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) { + modelResult := gjson.GetBytes(rawJSON, "model") + model := modelResult.String() + modelName := model + for { + if c.isModelQuotaExceeded(modelName) { + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + modelName = c.getPreviewModel(model) + if modelName != "" { + log.Debugf("Model %s is quota exceeded. 
Switch to preview model %s", model, modelName) + rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName) + continue + } + } + return nil, &ErrorMessage{ + StatusCode: 429, + Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), + } + } + + respBody, err := c.APIRequest(ctx, "countTokens", rawJSON, alt, false) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + continue + } + } + return nil, err + } + delete(c.modelQuotaExceeded, modelName) + bodyBytes, errReadAll := io.ReadAll(respBody) + if errReadAll != nil { + return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} + } + return bodyBytes, nil + } +} + +// SendRawMessage handles a single conversational turn, including tool calls. +func (c *GeminiClient) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) { + if c.glAPIKey == "" { + rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID()) + } + + modelResult := gjson.GetBytes(rawJSON, "model") + model := modelResult.String() + modelName := model + for { + if c.isModelQuotaExceeded(modelName) { + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + modelName = c.getPreviewModel(model) + if modelName != "" { + log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) + rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName) + continue + } + } + return nil, &ErrorMessage{ + StatusCode: 429, + Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), + } + } + + respBody, err := c.APIRequest(ctx, "generateContent", rawJSON, alt, false) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + continue + } + } + return nil, err + } + delete(c.modelQuotaExceeded, modelName) + bodyBytes, errReadAll := io.ReadAll(respBody) + if errReadAll != nil { + return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} + } + return bodyBytes, nil + } +} + +// SendRawMessageStream handles a single conversational turn, including tool calls. +func (c *GeminiClient) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) { + dataTag := []byte("data: ") + errChan := make(chan *ErrorMessage) + dataChan := make(chan []byte) + go func() { + defer close(errChan) + defer close(dataChan) + + if c.glAPIKey == "" { + rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID()) + } + + modelResult := gjson.GetBytes(rawJSON, "model") + model := modelResult.String() + modelName := model + var stream io.ReadCloser + for { + if c.isModelQuotaExceeded(modelName) { + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + modelName = c.getPreviewModel(model) + if modelName != "" { + log.Debugf("Model %s is quota exceeded. 
Switch to preview model %s", model, modelName) + rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName) + continue + } + } + errChan <- &ErrorMessage{ + StatusCode: 429, + Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), + } + return + } + var err *ErrorMessage + stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJSON, alt, true) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + continue + } + } + errChan <- err + return + } + delete(c.modelQuotaExceeded, modelName) + break + } + + if alt == "" { + scanner := bufio.NewScanner(stream) + for scanner.Scan() { + line := scanner.Bytes() + if bytes.HasPrefix(line, dataTag) { + dataChan <- line[6:] + } + } + + if errScanner := scanner.Err(); errScanner != nil { + errChan <- &ErrorMessage{500, errScanner} + _ = stream.Close() + return + } + + } else { + data, err := io.ReadAll(stream) + if err != nil { + errChan <- &ErrorMessage{500, err} + _ = stream.Close() + return + } + dataChan <- data + } + _ = stream.Close() + + }() + + return dataChan, errChan +} + +// isModelQuotaExceeded checks if the specified model has exceeded its quota +// within the last 30 minutes. +func (c *GeminiClient) isModelQuotaExceeded(model string) bool { + if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey { + duration := time.Now().Sub(*lastExceededTime) + if duration > 30*time.Minute { + return false + } + return true + } + return false +} + +// getPreviewModel returns an available preview model for the given base model, +// or an empty string if no preview models are available or all are quota exceeded. +func (c *GeminiClient) getPreviewModel(model string) string { + if models, hasKey := previewModels[model]; hasKey { + for i := 0; i < len(models); i++ { + if !c.isModelQuotaExceeded(models[i]) { + return models[i] + } + } + } + return "" +} + +// IsModelQuotaExceeded returns true if the specified model has exceeded its quota +// and no fallback options are available. +func (c *GeminiClient) IsModelQuotaExceeded(model string) bool { + if c.isModelQuotaExceeded(model) { + if c.cfg.QuotaExceeded.SwitchPreviewModel { + return c.getPreviewModel(model) == "" + } + return true + } + return false +} + +// CheckCloudAPIIsEnabled sends a simple test request to the API to verify +// that the Cloud AI API is enabled for the user's project. It provides +// an activation URL if the API is disabled. +func (c *GeminiClient) CheckCloudAPIIsEnabled() (bool, error) { + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + c.RequestMutex.Unlock() + cancel() + }() + c.RequestMutex.Lock() + + // A simple request to test the API endpoint. + requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID) + + stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), "", true) + if err != nil { + // If a 403 Forbidden error occurs, it likely means the API is not enabled. + if err.StatusCode == 403 { + errJSON := err.Error.Error() + // Check for a specific error code and extract the activation URL. 
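+			// The 403 body is expected to be shaped roughly like
+			// [{"error":{"code":403,"details":[{"metadata":{"activationUrl":"..."}}]}}],
+			// which is the shape the gjson paths below assume.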
+			if gjson.Get(errJSON, "0.error.code").Int() == 403 {
+				activationURL := gjson.Get(errJSON, "0.error.details.0.metadata.activationUrl").String()
+				if activationURL != "" {
+					log.Warnf(
+						"\n\nPlease activate your account at this URL:\n\n%s\n\nThen run this command again:\n%s --login --project_id %s",
+						activationURL,
+						os.Args[0],
+						c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID,
+					)
+				}
+			}
+			log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJSON)
+			return false, nil
+		}
+		return false, err.Error
+	}
+	defer func() {
+		_ = stream.Close()
+	}()
+
+	// We only need to know if the request was successful, so we can drain the stream.
+	scanner := bufio.NewScanner(stream)
+	for scanner.Scan() {
+		// Do nothing, just consume the stream.
+	}
+
+	return scanner.Err() == nil, scanner.Err()
+}
+
+// GetProjectList fetches a list of Google Cloud projects accessible by the user.
+func (c *GeminiClient) GetProjectList(ctx context.Context) (*GCPProject, error) {
+	token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get token: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
+	if err != nil {
+		return nil, fmt.Errorf("could not create project list request: %v", err)
+	}
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
+
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to execute project list request: %w", err)
+	}
+	defer func() {
+		_ = resp.Body.Close()
+	}()
+
+	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	var project GCPProject
+	if err = json.NewDecoder(resp.Body).Decode(&project); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal project list: %w", err)
+	}
+	return &project, nil
+}
+
+// SaveTokenToFile serializes the client's current token storage to a JSON file.
+// The filename is constructed from the user's email and project ID.
+func (c *GeminiClient) SaveTokenToFile() error {
+	fileName := filepath.Join(c.cfg.AuthDir, fmt.Sprintf("%s-%s.json", c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Email, c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID))
+	log.Infof("Saving credentials to %s", fileName)
+	return c.tokenStorage.SaveTokenToFile(fileName)
+}
+
+// getClientMetadata returns a map of metadata about the client environment,
+// such as IDE type, platform, and plugin version.
+func (c *GeminiClient) getClientMetadata() map[string]string {
+	return map[string]string{
+		"ideType":    "IDE_UNSPECIFIED",
+		"platform":   "PLATFORM_UNSPECIFIED",
+		"pluginType": "GEMINI",
+		// "pluginVersion": pluginVersion,
+	}
+}
+
+// getClientMetadataString returns the client metadata as a single,
+// comma-separated string, which is required for the 'Client-Metadata' header.
+func (c *GeminiClient) getClientMetadataString() string {
+	md := c.getClientMetadata()
+	parts := make([]string, 0, len(md))
+	for k, v := range md {
+		parts = append(parts, fmt.Sprintf("%s=%s", k, v))
+	}
+	return strings.Join(parts, ",")
+}
+
+// GetUserAgent constructs the User-Agent string for HTTP requests.
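+// The returned value currently mimics the official Node.js Google API client
+// rather than identifying this proxy itself; the commented-out format below is
+// the native GeminiCLI alternative.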
+func (c *GeminiClient) GetUserAgent() string { + // return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH) + return "google-api-nodejs-client/9.15.1" +} diff --git a/internal/client/models.go b/internal/client/models.go deleted file mode 100644 index eabc6b59..00000000 --- a/internal/client/models.go +++ /dev/null @@ -1,91 +0,0 @@ -package client - -import "time" - -// ErrorMessage encapsulates an error with an associated HTTP status code. -type ErrorMessage struct { - StatusCode int - Error error -} - -// GCPProject represents the response structure for a Google Cloud project list request. -type GCPProject struct { - Projects []GCPProjectProjects `json:"projects"` -} - -// GCPProjectLabels defines the labels associated with a GCP project. -type GCPProjectLabels struct { - GenerativeLanguage string `json:"generative-language"` -} - -// GCPProjectProjects contains details about a single Google Cloud project. -type GCPProjectProjects struct { - ProjectNumber string `json:"projectNumber"` - ProjectID string `json:"projectId"` - LifecycleState string `json:"lifecycleState"` - Name string `json:"name"` - Labels GCPProjectLabels `json:"labels"` - CreateTime time.Time `json:"createTime"` -} - -// Content represents a single message in a conversation, with a role and parts. -type Content struct { - Role string `json:"role"` - Parts []Part `json:"parts"` -} - -// Part represents a distinct piece of content within a message, which can be -// text, inline data (like an image), a function call, or a function response. -type Part struct { - Text string `json:"text,omitempty"` - InlineData *InlineData `json:"inlineData,omitempty"` - FunctionCall *FunctionCall `json:"functionCall,omitempty"` - FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` -} - -// InlineData represents base64-encoded data with its MIME type. -type InlineData struct { - MimeType string `json:"mime_type,omitempty"` - Data string `json:"data,omitempty"` -} - -// FunctionCall represents a tool call requested by the model, including the -// function name and its arguments. -type FunctionCall struct { - Name string `json:"name"` - Args map[string]interface{} `json:"args"` -} - -// FunctionResponse represents the result of a tool execution, sent back to the model. -type FunctionResponse struct { - Name string `json:"name"` - Response map[string]interface{} `json:"response"` -} - -// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint. -type GenerateContentRequest struct { - SystemInstruction *Content `json:"systemInstruction,omitempty"` - Contents []Content `json:"contents"` - Tools []ToolDeclaration `json:"tools,omitempty"` - GenerationConfig `json:"generationConfig"` -} - -// GenerationConfig defines parameters that control the model's generation behavior. -type GenerationConfig struct { - ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK float64 `json:"topK,omitempty"` -} - -// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process. -type GenerationConfigThinkingConfig struct { - // IncludeThoughts determines whether the model should output its reasoning process. - IncludeThoughts bool `json:"include_thoughts,omitempty"` -} - -// ToolDeclaration defines the structure for declaring tools (like functions) -// that the model can call. 
-type ToolDeclaration struct { - FunctionDeclarations []interface{} `json:"functionDeclarations"` -} diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 5d98e160..c7599fae 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -5,27 +5,33 @@ package cmd import ( "context" - "github.com/luispater/CLIProxyAPI/internal/auth" + "os" + + "github.com/luispater/CLIProxyAPI/internal/auth/gemini" "github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/config" log "github.com/sirupsen/logrus" - "os" ) // DoLogin handles the entire user login and setup process. // It authenticates the user, sets up the user's project, checks API enablement, // and saves the token for future use. -func DoLogin(cfg *config.Config, projectID string) { +func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + var err error - var ts auth.TokenStorage + var ts gemini.GeminiTokenStorage if projectID != "" { ts.ProjectID = projectID } // Initialize an authenticated HTTP client. This will trigger the OAuth flow if necessary. clientCtx := context.Background() - log.Info("Initializing authentication...") - httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg) + log.Info("Initializing Google authentication...") + geminiAuth := gemini.NewGeminiAuth() + httpClient, errGetClient := geminiAuth.GetAuthenticatedClient(clientCtx, &ts, cfg, options.NoBrowser) if errGetClient != nil { log.Fatalf("failed to get authenticated client: %v", errGetClient) return @@ -33,7 +39,7 @@ func DoLogin(cfg *config.Config, projectID string) { log.Info("Authentication successful.") // Initialize the API client. - cliClient := client.NewClient(httpClient, &ts, cfg) + cliClient := client.NewGeminiClient(httpClient, &ts, cfg) // Perform the user setup process. 
err = cliClient.SetupUser(clientCtx, ts.Email, projectID) diff --git a/internal/cmd/openai_login.go b/internal/cmd/openai_login.go new file mode 100644 index 00000000..ec4ba6c6 --- /dev/null +++ b/internal/cmd/openai_login.go @@ -0,0 +1,173 @@ +package cmd + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/luispater/CLIProxyAPI/internal/auth/codex" + "github.com/luispater/CLIProxyAPI/internal/browser" + "github.com/luispater/CLIProxyAPI/internal/client" + "github.com/luispater/CLIProxyAPI/internal/config" + log "github.com/sirupsen/logrus" +) + +// LoginOptions contains options for login +type LoginOptions struct { + NoBrowser bool +} + +// DoCodexLogin handles the Codex OAuth login process +func DoCodexLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + ctx := context.Background() + + log.Info("Initializing Codex authentication...") + + // Generate PKCE codes + pkceCodes, err := codex.GeneratePKCECodes() + if err != nil { + log.Fatalf("Failed to generate PKCE codes: %v", err) + return + } + + // Generate random state parameter + state, err := generateRandomState() + if err != nil { + log.Fatalf("Failed to generate state parameter: %v", err) + return + } + + // Initialize OAuth server + oauthServer := codex.NewOAuthServer(1455) + + // Start OAuth callback server + if err = oauthServer.Start(ctx); err != nil { + if strings.Contains(err.Error(), "already in use") { + authErr := codex.NewAuthenticationError(codex.ErrPortInUse, err) + log.Error(codex.GetUserFriendlyMessage(authErr)) + os.Exit(13) // Exit code 13 for port-in-use error + } + authErr := codex.NewAuthenticationError(codex.ErrServerStartFailed, err) + log.Fatalf("Failed to start OAuth callback server: %v", authErr) + return + } + defer func() { + if err = oauthServer.Stop(ctx); err != nil { + log.Warnf("Failed to stop OAuth server: %v", err) + } + }() + + // Initialize Codex auth service + openaiAuth := codex.NewCodexAuth(cfg) + + // Generate authorization URL + authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes) + if err != nil { + log.Fatalf("Failed to generate authorization URL: %v", err) + return + } + + // Open browser or display URL + if !options.NoBrowser { + log.Info("Opening browser for authentication...") + + // Check if browser is available + if !browser.IsAvailable() { + log.Warn("No browser available on this system") + log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL) + } else { + if err = browser.OpenURL(authURL); err != nil { + authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err) + log.Warn(codex.GetUserFriendlyMessage(authErr)) + log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL) + + // Log platform info for debugging + platformInfo := browser.GetPlatformInfo() + log.Debugf("Browser platform info: %+v", platformInfo) + } else { + log.Debug("Browser opened successfully") + } + } + } else { + log.Infof("Please open this URL in your browser:\n\n%s\n", authURL) + } + + log.Info("Waiting for authentication callback...") + + // Wait for OAuth callback + result, err := oauthServer.WaitForCallback(5 * time.Minute) + if err != nil { + if strings.Contains(err.Error(), "timeout") { + authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + log.Error(codex.GetUserFriendlyMessage(authErr)) + } else { + log.Errorf("Authentication failed: %v", err) + } + return + } + + if result.Error != "" { + oauthErr := 
codex.NewOAuthError(result.Error, "", http.StatusBadRequest) + log.Error(codex.GetUserFriendlyMessage(oauthErr)) + return + } + + // Validate state parameter + if result.State != state { + authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, result.State)) + log.Error(codex.GetUserFriendlyMessage(authErr)) + return + } + + log.Debug("Authorization code received, exchanging for tokens...") + + // Exchange authorization code for tokens + authBundle, err := openaiAuth.ExchangeCodeForTokens(ctx, result.Code, pkceCodes) + if err != nil { + authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err) + log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) + log.Debug("This may be due to network issues or invalid authorization code") + return + } + + // Create token storage + tokenStorage := openaiAuth.CreateTokenStorage(authBundle) + + // Initialize Codex client + openaiClient, err := client.NewCodexClient(cfg, tokenStorage) + if err != nil { + log.Fatalf("Failed to initialize Codex client: %v", err) + return + } + + // Save token storage + if err = openaiClient.SaveTokenToFile(); err != nil { + log.Fatalf("Failed to save authentication tokens: %v", err) + return + } + + log.Info("Authentication successful!") + if authBundle.APIKey != "" { + log.Info("API key obtained and saved") + } + + log.Info("You can now use Codex services through this CLI") +} + +// generateRandomState generates a cryptographically secure random state parameter +func generateRandomState() (string, error) { + bytes := make([]byte, 16) + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("failed to generate random bytes: %w", err) + } + return hex.EncodeToString(bytes), nil +} diff --git a/internal/cmd/run.go b/internal/cmd/run.go index b946bbe1..03b6677d 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -8,29 +8,37 @@ package cmd import ( "context" "encoding/json" - "github.com/luispater/CLIProxyAPI/internal/api" - "github.com/luispater/CLIProxyAPI/internal/auth" - "github.com/luispater/CLIProxyAPI/internal/client" - "github.com/luispater/CLIProxyAPI/internal/config" - "github.com/luispater/CLIProxyAPI/internal/util" - "github.com/luispater/CLIProxyAPI/internal/watcher" - log "github.com/sirupsen/logrus" "io/fs" "net/http" "os" "os/signal" "path/filepath" "strings" + "sync" "syscall" "time" + + "github.com/luispater/CLIProxyAPI/internal/api" + "github.com/luispater/CLIProxyAPI/internal/auth/codex" + "github.com/luispater/CLIProxyAPI/internal/auth/gemini" + "github.com/luispater/CLIProxyAPI/internal/client" + "github.com/luispater/CLIProxyAPI/internal/config" + "github.com/luispater/CLIProxyAPI/internal/util" + "github.com/luispater/CLIProxyAPI/internal/watcher" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" ) // StartService initializes and starts the main API proxy service. // It loads all available authentication tokens, creates a pool of clients, // starts the API server, and handles graceful shutdown signals. +// +// Parameters: +// - cfg: The application configuration +// - configPath: The path to the configuration file func StartService(cfg *config.Config, configPath string) { // Create a pool of API clients, one for each token file found. 
- cliClients := make([]*client.Client, 0) + cliClients := make([]client.Client, 0) err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error { if err != nil { return err @@ -39,31 +47,51 @@ func StartService(cfg *config.Config, configPath string) { // Process only JSON files in the auth directory. if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") { log.Debugf("Loading token from: %s", path) - f, errOpen := os.Open(path) - if errOpen != nil { - return errOpen + data, errReadFile := os.ReadFile(path) + if errReadFile != nil { + return errReadFile } - defer func() { - _ = f.Close() - }() - // Decode the token storage file. - var ts auth.TokenStorage - if err = json.NewDecoder(f).Decode(&ts); err == nil { - // For each valid token, create an authenticated client. - clientCtx := context.Background() - log.Info("Initializing authentication for token...") - httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg) - if errGetClient != nil { - // Log fatal will exit, but we return the error for completeness. - log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient) - return errGetClient + tokenType := "gemini" + typeResult := gjson.GetBytes(data, "type") + if typeResult.Exists() { + tokenType = typeResult.String() + } + + clientCtx := context.Background() + + if tokenType == "gemini" { + var ts gemini.GeminiTokenStorage + if err = json.Unmarshal(data, &ts); err == nil { + // For each valid token, create an authenticated client. + log.Info("Initializing gemini authentication for token...") + geminiAuth := gemini.NewGeminiAuth() + httpClient, errGetClient := geminiAuth.GetAuthenticatedClient(clientCtx, &ts, cfg) + if errGetClient != nil { + // Log fatal will exit, but we return the error for completeness. + log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient) + return errGetClient + } + log.Info("Authentication successful.") + + // Add the new client to the pool. + cliClient := client.NewGeminiClient(httpClient, &ts, cfg) + cliClients = append(cliClients, cliClient) + } + } else if tokenType == "codex" { + var ts codex.CodexTokenStorage + if err = json.Unmarshal(data, &ts); err == nil { + // For each valid token, create an authenticated client. + log.Info("Initializing codex authentication for token...") + codexClient, errGetClient := client.NewCodexClient(cfg, &ts) + if errGetClient != nil { + // Log fatal will exit, but we return the error for completeness. + log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient) + return errGetClient + } + log.Info("Authentication successful.") + cliClients = append(cliClients, codexClient) } - log.Info("Authentication successful.") - - // Add the new client to the pool. 
-			cliClient := client.NewClient(httpClient, &ts, cfg)
-			cliClients = append(cliClients, cliClient)
 			}
 		}
 		return nil
@@ -74,13 +102,10 @@ func StartService(cfg *config.Config, configPath string) {
 
 	if len(cfg.GlAPIKey) > 0 {
 		for i := 0; i < len(cfg.GlAPIKey); i++ {
-			httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{})
-			if errSetProxy != nil {
-				log.Fatalf("set proxy failed: %v", errSetProxy)
-			}
+			httpClient := util.SetProxy(cfg, &http.Client{})
 
 			log.Debug("Initializing with Generative Language API key...")
-			cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
+			cliClient := client.NewGeminiClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
 			cliClients = append(cliClients, cliClient)
 		}
 	}
@@ -101,7 +126,7 @@
 	log.Info("API server started successfully")
 
 	// Setup file watcher for config and auth directory changes
-	fileWatcher, errNewWatcher := watcher.NewWatcher(configPath, cfg.AuthDir, func(newClients []*client.Client, newCfg *config.Config) {
+	fileWatcher, errNewWatcher := watcher.NewWatcher(configPath, cfg.AuthDir, func(newClients []client.Client, newCfg *config.Config) {
 		// Update the API server with new clients and configuration
 		apiServer.UpdateClients(newClients, newCfg)
 	})
@@ -132,12 +157,50 @@ func StartService(cfg *config.Config, configPath string) {
 	sigChan := make(chan os.Signal, 1)
 	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
 
+	// Background token refresh ticker for Codex clients
+	ctxRefresh, cancelRefresh := context.WithCancel(context.Background())
+	var wgRefresh sync.WaitGroup
+	wgRefresh.Add(1)
+	go func() {
+		defer wgRefresh.Done()
+		ticker := time.NewTicker(1 * time.Hour)
+		defer ticker.Stop()
+		checkAndRefresh := func() {
+			for i := 0; i < len(cliClients); i++ {
+				if codexCli, ok := cliClients[i].(*client.CodexClient); ok {
+					ts := codexCli.TokenStorage().(*codex.CodexTokenStorage)
+					if ts != nil && ts.Expire != "" {
+						if expTime, errParse := time.Parse(time.RFC3339, ts.Expire); errParse == nil {
+							if time.Until(expTime) <= 5*24*time.Hour {
+								log.Debugf("refreshing codex tokens for %s", codexCli.GetEmail())
+								_ = codexCli.RefreshTokens(ctxRefresh)
+							}
+						}
+					}
+				}
+			}
+		}
+		// Initial check on start
+		checkAndRefresh()
+		for {
+			select {
+			case <-ctxRefresh.Done():
+				return
+			case <-ticker.C:
+				checkAndRefresh()
+			}
+		}
+	}()
+
 	// Main loop to wait for shutdown signal.
 	for {
 		select {
 		case <-sigChan:
 			log.Debugf("Received shutdown signal. Cleaning up...")
 
+			cancelRefresh()
+			wgRefresh.Wait()
+
 			// Create a context with a timeout for the shutdown process.
 			ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 			_ = cancel
@@ -150,8 +213,6 @@ func StartService(cfg *config.Config, configPath string) {
 			log.Debugf("Cleanup completed. Exiting...")
 			os.Exit(0)
 		case <-time.After(5 * time.Second):
-			// This case is currently empty and acts as a periodic check.
-			// It could be used for periodic tasks in the future.
 		}
 	}
 }
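The loader above keys each JSON file in `auth-dir` off its `type` field, defaulting to `gemini` when the field is absent. A minimal sketch of that dispatch; the token payload below is hypothetical (only the `type` key is actually read by this logic):

```go
package main

import (
	"fmt"

	"github.com/tidwall/gjson"
)

func main() {
	// Hypothetical token file contents; run.go only inspects "type" here.
	data := []byte(`{"type":"codex","expire":"2025-12-31T00:00:00Z"}`)

	// Default to "gemini" when no explicit type is present, mirroring StartService.
	tokenType := "gemini"
	if t := gjson.GetBytes(data, "type"); t.Exists() {
		tokenType = t.String()
	}

	switch tokenType {
	case "codex":
		fmt.Println("would build a *client.CodexClient")
	default:
		fmt.Println("would build a Gemini client")
	}
}
```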
diff --git a/internal/config/config.go b/internal/config/config.go
index 0e8368a3..3cd22fda 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -6,26 +6,36 @@ package config
 import (
 	"fmt"
-	"gopkg.in/yaml.v3"
 	"os"
+
+	"gopkg.in/yaml.v3"
 )
 
 // Config represents the application's configuration, loaded from a YAML file.
 type Config struct {
 	// Port is the network port on which the API server will listen.
 	Port int `yaml:"port"`
+	// AuthDir is the directory where authentication token files are stored.
 	AuthDir string `yaml:"auth-dir"`
+	// Debug enables or disables debug-level logging and other debug features.
 	Debug bool `yaml:"debug"`
+	// ProxyURL is the URL of an optional proxy server to use for outbound requests.
 	ProxyURL string `yaml:"proxy-url"`
+	// APIKeys is a list of keys for authenticating clients to this proxy server.
 	APIKeys []string `yaml:"api-keys"`
+	// QuotaExceeded defines the behavior when a quota is exceeded.
 	QuotaExceeded QuotaExceeded `yaml:"quota-exceeded"`
+	// GlAPIKey is the API key for the generative language API.
 	GlAPIKey []string `yaml:"generative-language-api-key"`
+
+	// RequestLog enables or disables detailed request logging functionality.
+	RequestLog bool `yaml:"request-log"`
 }
 
 // QuotaExceeded defines the behavior when API quota limits are exceeded.
@@ -33,12 +43,21 @@ type Config struct {
 type QuotaExceeded struct {
 	// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
 	SwitchProject bool `yaml:"switch-project"`
+	// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
 	SwitchPreviewModel bool `yaml:"switch-preview-model"`
 }
 
 // LoadConfig reads a YAML configuration file from the given path,
-// unmarshals it into a Config struct, and returns it.
+// unmarshals it into a Config struct, applies environment variable overrides,
+// and returns it.
+//
+// Parameters:
+//   - configFile: The path to the YAML configuration file
+//
+// Returns:
+//   - *Config: The loaded configuration
+//   - error: An error if the configuration could not be loaded
 func LoadConfig(configFile string) (*Config, error) {
 	// Read the entire configuration file into memory.
 	data, err := os.ReadFile(configFile)
diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go
new file mode 100644
index 00000000..a719ce96
--- /dev/null
+++ b/internal/logging/request_logger.go
@@ -0,0 +1,390 @@
+// Package logging provides request logging functionality for the CLI Proxy API server.
+// It handles capturing and storing detailed HTTP request and response data when enabled
+// through configuration, supporting both regular and streaming responses.
+package logging
+
+import (
+	"bytes"
+	"compress/flate"
+	"compress/gzip"
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+	"regexp"
+	"strings"
+	"time"
+)
+
+// RequestLogger defines the interface for logging HTTP requests and responses.
+type RequestLogger interface {
+	// LogRequest logs a complete non-streaming request/response cycle
+	LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response []byte) error
+
+	// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks
+	LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error)
+
+	// IsEnabled returns whether request logging is currently enabled
+	IsEnabled() bool
+}
+
+// StreamingLogWriter handles real-time logging of streaming response chunks.
+type StreamingLogWriter interface {
+	// WriteChunkAsync writes a response chunk asynchronously (non-blocking)
+	WriteChunkAsync(chunk []byte)
+
+	// WriteStatus writes the response status and headers to the log
+	WriteStatus(status int, headers map[string][]string) error
+
+	// Close finalizes the log file and cleans up resources
+	Close() error
+}
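A minimal usage sketch of the two interfaces above, wired to the file-based implementation that follows; the URL, headers, and the `logs` directory name are illustrative:

```go
package main

import (
	"log"

	"github.com/luispater/CLIProxyAPI/internal/logging"
)

func main() {
	// Enabled logger writing into an assumed "logs" directory.
	var rl logging.RequestLogger = logging.NewFileRequestLogger(true, "logs")

	// Non-streaming: one call captures the full request/response cycle.
	err := rl.LogRequest(
		"/v1/chat/completions", "POST",
		map[string][]string{"Content-Type": {"application/json"}},
		[]byte(`{"model":"gpt-5"}`),
		200,
		map[string][]string{"Content-Type": {"application/json"}},
		[]byte(`{"ok":true}`),
	)
	if err != nil {
		log.Fatal(err)
	}

	// Streaming: obtain a writer, then feed chunks as they arrive.
	w, err := rl.LogStreamingRequest("/v1/messages", "POST", nil, []byte(`{}`))
	if err != nil {
		log.Fatal(err)
	}
	_ = w.WriteStatus(200, nil)
	w.WriteChunkAsync([]byte("data: {...}\n"))
	_ = w.Close()
}
```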
+// FileRequestLogger implements RequestLogger using file-based storage.
+type FileRequestLogger struct {
+	enabled bool
+	logsDir string
+}
+
+// NewFileRequestLogger creates a new file-based request logger.
+func NewFileRequestLogger(enabled bool, logsDir string) *FileRequestLogger {
+	return &FileRequestLogger{
+		enabled: enabled,
+		logsDir: logsDir,
+	}
+}
+
+// IsEnabled returns whether request logging is currently enabled.
+func (l *FileRequestLogger) IsEnabled() bool {
+	return l.enabled
+}
+
+// LogRequest logs a complete non-streaming request/response cycle to a file.
+func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response []byte) error {
+	if !l.enabled {
+		return nil
+	}
+
+	// Ensure logs directory exists
+	if err := l.ensureLogsDir(); err != nil {
+		return fmt.Errorf("failed to create logs directory: %w", err)
+	}
+
+	// Generate filename
+	filename := l.generateFilename(url)
+	filePath := filepath.Join(l.logsDir, filename)
+
+	// Decompress response if needed
+	decompressedResponse, err := l.decompressResponse(responseHeaders, response)
+	if err != nil {
+		// If decompression fails, log the error but continue with original response
+		decompressedResponse = append(response, []byte(fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", err))...)
+	}
+
+	// Create log content
+	content := l.formatLogContent(url, method, requestHeaders, body, decompressedResponse, statusCode, responseHeaders)
+
+	// Write to file
+	if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
+		return fmt.Errorf("failed to write log file: %w", err)
+	}
+
+	return nil
+}
+
+// LogStreamingRequest initiates logging for a streaming request.
+func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) {
+	if !l.enabled {
+		return &NoOpStreamingLogWriter{}, nil
+	}
+
+	// Ensure logs directory exists
+	if err := l.ensureLogsDir(); err != nil {
+		return nil, fmt.Errorf("failed to create logs directory: %w", err)
+	}
+
+	// Generate filename
+	filename := l.generateFilename(url)
+	filePath := filepath.Join(l.logsDir, filename)
+
+	// Create and open file
+	file, err := os.Create(filePath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create log file: %w", err)
+	}
+
+	// Write initial request information
+	requestInfo := l.formatRequestInfo(url, method, headers, body)
+	if _, err := file.WriteString(requestInfo); err != nil {
+		_ = file.Close()
+		return nil, fmt.Errorf("failed to write request info: %w", err)
+	}
+
+	// Create streaming writer
+	writer := &FileStreamingLogWriter{
+		file:      file,
+		chunkChan: make(chan []byte, 100), // Buffered channel for async writes
+		closeChan: make(chan struct{}),
+		errorChan: make(chan error, 1),
+	}
+
+	// Start async writer goroutine
+	go writer.asyncWriter()
+
+	return writer, nil
+}
+
+// ensureLogsDir creates the logs directory if it doesn't exist.
+func (l *FileRequestLogger) ensureLogsDir() error {
+	if _, err := os.Stat(l.logsDir); os.IsNotExist(err) {
+		return os.MkdirAll(l.logsDir, 0755)
+	}
+	return nil
+}
+
+// generateFilename creates a sanitized filename from the URL path and current timestamp.
+func (l *FileRequestLogger) generateFilename(url string) string {
+	// Extract path from URL
+	path := url
+	if strings.Contains(url, "?") {
+		path = strings.Split(url, "?")[0]
+	}
+
+	// Remove leading slash
+	if strings.HasPrefix(path, "/") {
+		path = path[1:]
+	}
+
+	// Sanitize path for filename
+	sanitized := l.sanitizeForFilename(path)
+
+	// Add timestamp
+	timestamp := time.Now().UnixNano()
+
+	return fmt.Sprintf("%s-%d.log", sanitized, timestamp)
+}
+
+// sanitizeForFilename replaces characters that are not safe for filenames.
+func (l *FileRequestLogger) sanitizeForFilename(path string) string {
+	// Replace slashes with hyphens
+	sanitized := strings.ReplaceAll(path, "/", "-")
+
+	// Replace colons with hyphens
+	sanitized = strings.ReplaceAll(sanitized, ":", "-")
+
+	// Replace other problematic characters with hyphens
+	reg := regexp.MustCompile(`[<>:"|?*\s]`)
+	sanitized = reg.ReplaceAllString(sanitized, "-")
+
+	// Remove multiple consecutive hyphens
+	reg = regexp.MustCompile(`-+`)
+	sanitized = reg.ReplaceAllString(sanitized, "-")
+
+	// Remove leading/trailing hyphens
+	sanitized = strings.Trim(sanitized, "-")
+
+	// Handle empty result
+	if sanitized == "" {
+		sanitized = "root"
+	}
+
+	return sanitized
+}
+
+// formatLogContent creates the complete log content for non-streaming requests.
+func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body []byte, response []byte, status int, responseHeaders map[string][]string) string {
+	var content strings.Builder
+
+	// Request info
+	content.WriteString(l.formatRequestInfo(url, method, headers, body))
+
+	// Response section
+	content.WriteString("========================================\n")
+	content.WriteString("=== RESPONSE ===\n")
+	content.WriteString(fmt.Sprintf("Status: %d\n", status))
+
+	if responseHeaders != nil {
+		for key, values := range responseHeaders {
+			for _, value := range values {
+				content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
+			}
+		}
+	}
+
+	content.WriteString("\n")
+	content.Write(response)
+	content.WriteString("\n")
+
+	return content.String()
+}
+
+// decompressResponse decompresses response data based on the Content-Encoding header.
+func (l *FileRequestLogger) decompressResponse(responseHeaders map[string][]string, response []byte) ([]byte, error) {
+	if responseHeaders == nil || len(response) == 0 {
+		return response, nil
+	}
+
+	// Check Content-Encoding header
+	var contentEncoding string
+	for key, values := range responseHeaders {
+		if strings.ToLower(key) == "content-encoding" && len(values) > 0 {
+			contentEncoding = strings.ToLower(values[0])
+			break
+		}
+	}
+
+	switch contentEncoding {
+	case "gzip":
+		return l.decompressGzip(response)
+	case "deflate":
+		return l.decompressDeflate(response)
+	default:
+		// No compression or unsupported compression
+		return response, nil
+	}
+}
+
+// decompressGzip decompresses gzip-encoded data.
+func (l *FileRequestLogger) decompressGzip(data []byte) ([]byte, error) {
+	reader, err := gzip.NewReader(bytes.NewReader(data))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create gzip reader: %w", err)
+	}
+	defer reader.Close()
+
+	decompressed, err := io.ReadAll(reader)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decompress gzip data: %w", err)
+	}
+
+	return decompressed, nil
+}
+
+// decompressDeflate decompresses deflate-encoded data.
+func (l *FileRequestLogger) decompressDeflate(data []byte) ([]byte, error) {
+	reader := flate.NewReader(bytes.NewReader(data))
+	defer reader.Close()
+
+	decompressed, err := io.ReadAll(reader)
+	if err != nil {
+		return nil, fmt.Errorf("failed to decompress deflate data: %w", err)
+	}
+
+	return decompressed, nil
+}
+
+// formatRequestInfo creates the request information section of the log.
+func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string {
+	var content strings.Builder
+
+	content.WriteString("=== REQUEST INFO ===\n")
+	content.WriteString(fmt.Sprintf("URL: %s\n", url))
+	content.WriteString(fmt.Sprintf("Method: %s\n", method))
+	content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
+	content.WriteString("\n")
+
+	content.WriteString("=== HEADERS ===\n")
+	for key, values := range headers {
+		for _, value := range values {
+			content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
+		}
+	}
+	content.WriteString("\n")
+
+	content.WriteString("=== REQUEST BODY ===\n")
+	content.Write(body)
+	content.WriteString("\n\n")
+
+	return content.String()
+}
+
+// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
+type FileStreamingLogWriter struct {
+	file          *os.File
+	chunkChan     chan []byte
+	closeChan     chan struct{}
+	errorChan     chan error
+	statusWritten bool
+}
+
+// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
+func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) {
+	if w.chunkChan == nil {
+		return
+	}
+
+	// Make a copy of the chunk to avoid data races
+	chunkCopy := make([]byte, len(chunk))
+	copy(chunkCopy, chunk)
+
+	// Non-blocking send
+	select {
+	case w.chunkChan <- chunkCopy:
+	default:
+		// Channel is full, skip this chunk to avoid blocking
+	}
+}
+
+// WriteStatus writes the response status and headers to the log.
+func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error {
+	if w.file == nil || w.statusWritten {
+		return nil
+	}
+
+	var content strings.Builder
+	content.WriteString("========================================\n")
+	content.WriteString("=== RESPONSE ===\n")
+	content.WriteString(fmt.Sprintf("Status: %d\n", status))
+
+	for key, values := range headers {
+		for _, value := range values {
+			content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
+		}
+	}
+	content.WriteString("\n")
+
+	_, err := w.file.WriteString(content.String())
+	if err == nil {
+		w.statusWritten = true
+	}
+	return err
+}
+
+// Close finalizes the log file and cleans up resources.
+func (w *FileStreamingLogWriter) Close() error {
+	if w.chunkChan != nil {
+		close(w.chunkChan)
+	}
+
+	// Wait for async writer to finish
+	if w.closeChan != nil {
+		<-w.closeChan
+		w.chunkChan = nil
+	}
+
+	if w.file != nil {
+		return w.file.Close()
+	}
+
+	return nil
+}
+
+// asyncWriter runs in a goroutine to handle async chunk writing.
+func (w *FileStreamingLogWriter) asyncWriter() {
+	defer close(w.closeChan)
+
+	for chunk := range w.chunkChan {
+		if w.file != nil {
+			_, _ = w.file.Write(chunk)
+		}
+	}
+}
+
+// NoOpStreamingLogWriter is a no-operation implementation for when logging is disabled.
+type NoOpStreamingLogWriter struct{} + +func (w *NoOpStreamingLogWriter) WriteChunkAsync(chunk []byte) {} +func (w *NoOpStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error { + return nil +} +func (w *NoOpStreamingLogWriter) Close() error { return nil } diff --git a/internal/misc/codex_instructions.go b/internal/misc/codex_instructions.go new file mode 100644 index 00000000..e4c88f40 --- /dev/null +++ b/internal/misc/codex_instructions.go @@ -0,0 +1,6 @@ +package misc + +import _ "embed" + +//go:embed codex_instructions.txt +var CodexInstructions string diff --git a/internal/misc/codex_instructions.txt b/internal/misc/codex_instructions.txt new file mode 100644 index 00000000..c95b946b --- /dev/null +++ b/internal/misc/codex_instructions.txt @@ -0,0 +1 @@ +"You are a coding agent running in the Codex CLI, a terminal-based coding assistant. Codex CLI is an open source project led by OpenAI. You are expected to be precise, safe, and helpful.\n\nYour capabilities:\n- Receive user prompts and other context provided by the harness, such as files in the workspace.\n- Communicate with the user by streaming thinking & responses, and by making & updating plans.\n- Emit function calls to run terminal commands and apply patches. Depending on how this specific run is configured, you can request that these function calls be escalated to the user for approval before running. More on this in the \"Sandbox and approvals\" section.\n\nWithin this context, Codex refers to the open-source agentic coding interface (not the old Codex language model built by OpenAI).\n\n# How you work\n\n## Personality\n\nYour default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work.\n\n## Responsiveness\n\n### Preamble messages\n\nBefore making tool calls, send a brief preamble to the user explaining what you’re about to do. When sending preamble messages, follow these principles and examples:\n\n- **Logically group related actions**: if you’re about to run several related commands, describe them together in one preamble rather than sending a separate note for each.\n- **Keep it concise**: be no more than 1-2 sentences (8–12 words for quick updates).\n- **Build on prior context**: if this is not your first tool call, use the preamble message to connect the dots with what’s been done so far and create a sense of momentum and clarity for the user to understand your next actions.\n- **Keep your tone light, friendly and curious**: add small touches of personality in preambles feel collaborative and engaging.\n\n**Examples:**\n- “I’ve explored the repo; now checking the API route definitions.”\n- “Next, I’ll patch the config and update the related tests.”\n- “I’m about to scaffold the CLI commands and helper functions.”\n- “Ok cool, so I’ve wrapped my head around the repo. Now digging into the API routes.”\n- “Config’s looking tidy. Next up is patching helpers to keep things in sync.”\n- “Finished poking at the DB gateway. I will now chase down error handling.”\n- “Alright, build pipeline order is interesting. 
Checking how it reports failures.”\n- “Spotted a clever caching util; now hunting where it gets used.”\n\n**Avoiding a preamble for every trivial read (e.g., `cat` a single file) unless it’s part of a larger grouped action.\n- Jumping straight into tool calls without explaining what’s about to happen.\n- Writing overly long or speculative preambles — focus on immediate, tangible next steps.\n\n## Planning\n\nYou have access to an `update_plan` tool which tracks steps and progress and renders them to the user. Using the tool helps demonstrate that you've understood the task and convey how you're approaching it. Plans can help to make complex, ambiguous, or multi-phase work clearer and more collaborative for the user. A good plan should break the task into meaningful, logically ordered steps that are easy to verify as you go. Note that plans are not for padding out simple work with filler steps or stating the obvious. Do not repeat the full contents of the plan after an `update_plan` call — the harness already displays it. Instead, summarize the change made and highlight any important context or next step.\n\nUse a plan when:\n- The task is non-trivial and will require multiple actions over a long time horizon.\n- There are logical phases or dependencies where sequencing matters.\n- The work has ambiguity that benefits from outlining high-level goals.\n- You want intermediate checkpoints for feedback and validation.\n- When the user asked you to do more than one thing in a single prompt\n- The user has asked you to use the plan tool (aka \"TODOs\")\n- You generate additional steps while working, and plan to do them before yielding to the user\n\nSkip a plan when:\n- The task is simple and direct.\n- Breaking it down would only produce literal or trivial steps.\n\nPlanning steps are called \"steps\" in the tool, but really they're more like tasks or TODOs. As such they should be very concise descriptions of non-obvious work that an engineer might do like \"Write the API spec\", then \"Update the backend\", then \"Implement the frontend\". On the other hand, it's obvious that you'll usually have to \"Explore the codebase\" or \"Implement the changes\", so those are not worth tracking in your plan.\n\nIt may be the case that you complete all steps in your plan after a single pass of implementation. If this is the case, you can simply mark all the planned steps as completed. The content of your plan should not involve doing anything that you aren't capable of doing (i.e. don't try to test things that you can't test). Do not use plans for simple or single-step queries that you can just do or answer immediately.\n\n### Examples\n\n**High-quality plans**\n\nExample 1:\n\n1. Add CLI entry with file args\n2. Parse Markdown via CommonMark library\n3. Apply semantic HTML template\n4. Handle code blocks, images, links\n5. Add error handling for invalid files\n\nExample 2:\n\n1. Define CSS variables for colors\n2. Add toggle with localStorage state\n3. Refactor components to use variables\n4. Verify all views for readability\n5. Add smooth theme-change transition\n\nExample 3:\n\n1. Set up Node.js + WebSocket server\n2. Add join/leave broadcast events\n3. Implement messaging with timestamps\n4. Add usernames + mention highlighting\n5. Persist messages in lightweight DB\n6. Add typing indicators + unread count\n\n**Low-quality plans**\n\nExample 1:\n\n1. Create CLI tool\n2. Add Markdown parser\n3. Convert to HTML\n\nExample 2:\n\n1. Add dark mode toggle\n2. Save preference\n3. 
Make styles look good\n\nExample 3:\n\n1. Create single-file HTML game\n2. Run quick sanity check\n3. Summarize usage instructions\n\nIf you need to write a plan, only write high quality plans, not low quality ones.\n\n## Task execution\n\nYou are a coding agent. Please keep going until the query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. Autonomously resolve the query to the best of your ability, using the tools available to you, before coming back to the user. Do NOT guess or make up an answer.\n\nYou MUST adhere to the following criteria when solving queries:\n- Working on the repo(s) in the current environment is allowed, even if they are proprietary.\n- Analyzing code for vulnerabilities is allowed.\n- Showing user code and tool call details is allowed.\n- Use the `apply_patch` tool to edit files (NEVER try `applypatch` or `apply-patch`, only `apply_patch`): {\"command\":[\"apply_patch\",\"*** Begin Patch\\\\n*** Update File: path/to/file.py\\\\n@@ def example():\\\\n- pass\\\\n+ return 123\\\\n*** End Patch\"]}\n\nIf completing the user's task requires writing or modifying files, your code and final answer should follow these coding guidelines, though user instructions (i.e. AGENTS.md) may override these guidelines:\n\n- Fix the problem at the root cause rather than applying surface-level patches, when possible.\n- Avoid unneeded complexity in your solution.\n- Do not attempt to fix unrelated bugs or broken tests. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)\n- Update documentation as necessary.\n- Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task.\n- Use `git log` and `git blame` to search the history of the codebase if additional context is required.\n- NEVER add copyright or license headers unless specifically requested.\n- Do not waste tokens by re-reading files after calling `apply_patch` on them. The tool call will fail if it didn't work. The same goes for making folders, deleting folders, etc.\n- Do not `git commit` your changes or create new git branches unless explicitly requested.\n- Do not add inline comments within code unless explicitly requested.\n- Do not use one-letter variable names unless explicitly requested.\n- NEVER output inline citations like \"【F:README.md†L5-L14】\" in your outputs. The CLI is not able to render these so they will just be broken in the UI. Instead, if you output valid filepaths, users will be able to click on them to open the files in their editor.\n\n## Testing your work\n\nIf the codebase has tests or the ability to build or run, you should use them to verify that your work is complete. Generally, your testing philosophy should be to start as specific as possible to the code you changed so that you can catch issues efficiently, then make your way to broader tests as you build confidence. If there's no test for the code you changed, and if the adjacent patterns in the codebases show that there's a logical place for you to add a test, you may do so. However, do not add tests to codebases with no tests, or where the patterns don't indicate so.\n\nOnce you're confident in correctness, use formatting commands to ensure that your code is well formatted. These commands can take time so you should run them on as precise a target as possible. 
If there are issues you can iterate up to 3 times to get formatting right, but if you still can't manage it's better to save the user time and present them a correct solution where you call out the formatting in your final message. If the codebase does not have a formatter configured, do not add one.\n\nFor all of testing, running, building, and formatting, do not attempt to fix unrelated bugs. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)\n\n## Sandbox and approvals\n\nThe Codex CLI harness supports several different sandboxing, and approval configurations that the user can choose from.\n\nFilesystem sandboxing prevents you from editing files without user approval. The options are:\n- *read-only*: You can only read files.\n- *workspace-write*: You can read files. You can write to files in your workspace folder, but not outside it.\n- *danger-full-access*: No filesystem sandboxing.\n\nNetwork sandboxing prevents you from accessing network without approval. Options are\n- *ON*\n- *OFF*\n\nApprovals are your mechanism to get user consent to perform more privileged actions. Although they introduce friction to the user because your work is paused until the user responds, you should leverage them to accomplish your important work. Do not let these settings or the sandbox deter you from attempting to accomplish the user's task. Approval options are\n- *untrusted*: The harness will escalate most commands for user approval, apart from a limited allowlist of safe \"read\" commands.\n- *on-failure*: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.\n- *on-request*: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)\n- *never*: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is pared with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.\n\nWhen you are running with approvals `on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:\n- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /tmp)\n- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.\n- You are running sandboxed and need to run a command that requires network access (e.g. 
installing packages)\n- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval.\n- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for\n- (For all of these, you should weigh alternative paths that do not require approval.)\n\nNote that when sandboxing is set to read-only, you'll need to request approval for any command that isn't a read.\n\nYou will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing ON, and approval on-failure.\n\n## Ambition vs. precision\n\nFor tasks that have no prior context (i.e. the user is starting something brand new), you should feel free to be ambitious and demonstrate creativity with your implementation.\n\nIf you're operating in an existing codebase, you should make sure you do exactly what the user asks with surgical precision. Treat the surrounding codebase with respect, and don't overstep (i.e. changing filenames or variables unnecessarily). You should balance being sufficiently ambitious and proactive when completing tasks of this nature.\n\nYou should use judicious initiative to decide on the right level of detail and complexity to deliver based on the user's needs. This means showing good judgment that you're capable of doing the right extras without gold-plating. This might be demonstrated by high-value, creative touches when scope of the task is vague; while being surgical and targeted when scope is tightly specified.\n\n## Sharing progress updates\n\nFor especially longer tasks that you work on (i.e. requiring many tool calls, or a plan with multiple steps), you should provide progress updates back to the user at reasonable intervals. These updates should be structured as a concise sentence or two (no more than 8-10 words long) recapping progress so far in plain language: this update demonstrates your understanding of what needs to be done, progress so far (i.e. files explores, subtasks complete), and where you're going next.\n\nBefore doing large chunks of work that may incur latency as experienced by the user (i.e. writing a new file), you should send a concise message to the user with an update indicating what you're about to do to ensure they know what you're spending time on. Don't start editing or writing large files before informing the user what you are doing and why.\n\nThe messages you send before tool calls should describe what is immediately about to be done next in very concise language. If there was previous work done, this preamble message should also include a note about the work done so far to bring the user along.\n\n## Presenting your work and final message\n\nYour final message should read naturally, like an update from a concise teammate. For casual conversation, brainstorming tasks, or quick questions from the user, respond in a friendly, conversational tone. You should ask questions, suggest ideas, and adapt to the user’s style. If you've finished a large amount of work, when describing what you've done to the user, you should follow the final answer formatting guidelines to communicate substantive changes. You don't need to add structured formatting for one-word answers, greetings, or purely conversational exchanges.\n\nYou can skip heavy formatting for single, simple actions or confirmations. 
In these cases, respond in plain sentences with any relevant next step or quick option. Reserve multi-section structured responses for results that need grouping or explanation.\n\nThe user is working on the same computer as you, and has access to your work. As such there's no need to show the full contents of large files you have already written unless the user explicitly asks for them. Similarly, if you've created or modified files using `apply_patch`, there's no need to tell users to \"save the file\" or \"copy the code into a file\"—just reference the file path.\n\nIf there's something that you think you could help with as a logical next step, concisely ask the user if they want you to do so. Good examples of this are running tests, committing changes, or building out the next logical component. If there’s something that you couldn't do (even with approval) but that the user might want to do (such as verifying changes by running the app), include those instructions succinctly.\n\nBrevity is very important as a default. You should be very concise (i.e. no more than 10 lines), but can relax this requirement for tasks where additional detail and comprehensiveness is important for the user's understanding.\n\n### Final answer structure and style guidelines\n\nYou are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.\n\n**Section Headers**\n- Use only when they improve clarity — they are not mandatory for every answer.\n- Choose descriptive names that fit the content\n- Keep headers short (1–3 words) and in `**Title Case**`. Always start headers with `**` and end with `**`\n- Leave no blank line before the first bullet under a header.\n- Section headers should only be used where they genuinely improve scanability; avoid fragmenting the answer.\n\n**Bullets**\n- Use `-` followed by a space for every bullet.\n- Bold the keyword, then colon + concise description.\n- Merge related points when possible; avoid a bullet for every trivial detail.\n- Keep bullets to one line unless breaking for clarity is unavoidable.\n- Group into short lists (4–6 bullets) ordered by importance.\n- Use consistent keyword phrasing and formatting across sections.\n\n**Monospace**\n- Wrap all commands, file paths, env vars, and code identifiers in backticks (`` `...` ``).\n- Apply to inline examples and to bullet keywords if the keyword itself is a literal file/command.\n- Never mix monospace and bold markers; choose one based on whether it’s a keyword (`**`) or inline code/path (`` ` ``).\n\n**Structure**\n- Place related bullets together; don’t mix unrelated concepts in the same section.\n- Order sections from general → specific → supporting info.\n- For subsections (e.g., “Binaries” under “Rust Workspace”), introduce with a bolded keyword bullet, then list items under it.\n- Match structure to complexity:\n - Multi-part or detailed results → use clear headers and grouped bullets.\n - Simple results → minimal headers, possibly just a short list or paragraph.\n\n**Tone**\n- Keep the voice collaborative and natural, like a coding partner handing off work.\n- Be concise and factual — no filler or conversational commentary and avoid unnecessary repetition\n- Use present tense and active voice (e.g., “Runs tests” not “This will run tests”).\n- Keep descriptions self-contained; don’t refer to “above” or “below”.\n- Use parallel structure in lists for 
consistency.\n\n**Don’t**\n- Don’t use literal words “bold” or “monospace” in the content.\n- Don’t nest bullets or create deep hierarchies.\n- Don’t output ANSI escape codes directly — the CLI renderer applies them.\n- Don’t cram unrelated keywords into a single bullet; split for clarity.\n- Don’t let keyword lists run long — wrap or reformat for scanability.\n\nGenerally, ensure your final answers adapt their shape and depth to the request. For example, answers to code explanations should have a precise, structured explanation with code references that answer the question directly. For tasks with a simple implementation, lead with the outcome and supplement only with what’s needed for clarity. Larger changes can be presented as a logical walkthrough of your approach, grouping related steps, explaining rationale where it adds value, and highlighting next actions to accelerate the user. Your answers should provide the right level of detail while being easily scannable.\n\nFor casual greetings, acknowledgements, or other one-off conversational messages that are not delivering substantive information or structured results, respond naturally without section headers or bullet formatting.\n\n# Tools\n\n## `apply_patch`\n\nYour patch language is a stripped‑down, file‑oriented diff format designed to be easy to parse and safe to apply. You can think of it as a high‑level envelope:\n\n**_ Begin Patch\n[ one or more file sections ]\n_** End Patch\n\nWithin that envelope, you get a sequence of file operations.\nYou MUST include a header to specify the action you are taking.\nEach operation starts with one of three headers:\n\n**_ Add File: - create a new file. Every following line is a + line (the initial contents).\n_** Delete File: - remove an existing file. Nothing follows.\n\\*\\*\\* Update File: - patch an existing file in place (optionally with a rename).\n\nMay be immediately followed by \\*\\*\\* Move to: if you want to rename the file.\nThen one or more “hunks”, each introduced by @@ (optionally followed by a hunk header).\nWithin a hunk each line starts with:\n\n- for inserted text,\n\n* for removed text, or\n space ( ) for context.\n At the end of a truncated hunk you can emit \\*\\*\\* End of File.\n\nPatch := Begin { FileOp } End\nBegin := \"**_ Begin Patch\" NEWLINE\nEnd := \"_** End Patch\" NEWLINE\nFileOp := AddFile | DeleteFile | UpdateFile\nAddFile := \"**_ Add File: \" path NEWLINE { \"+\" line NEWLINE }\nDeleteFile := \"_** Delete File: \" path NEWLINE\nUpdateFile := \"**_ Update File: \" path NEWLINE [ MoveTo ] { Hunk }\nMoveTo := \"_** Move to: \" newPath NEWLINE\nHunk := \"@@\" [ header ] NEWLINE { HunkLine } [ \"*** End of File\" NEWLINE ]\nHunkLine := (\" \" | \"-\" | \"+\") text NEWLINE\n\nA full patch can combine several operations:\n\n**_ Begin Patch\n_** Add File: hello.txt\n+Hello world\n**_ Update File: src/app.py\n_** Move to: src/main.py\n@@ def greet():\n-print(\"Hi\")\n+print(\"Hello, world!\")\n**_ Delete File: obsolete.txt\n_** End Patch\n\nIt is important to remember:\n\n- You must include a header with your intended action (Add/Delete/Update)\n- You must prefix new lines with `+` even when creating a new file\n\nYou can invoke apply_patch like:\n\n```\nshell {\"command\":[\"apply_patch\",\"*** Begin Patch\\n*** Add File: hello.txt\\n+Hello, world!\\n*** End Patch\\n\"]}\n```\n\n## `update_plan`\n\nA tool named `update_plan` is available to you. 
You can use it to keep an up‑to‑date, step‑by‑step plan for the task.\n\nTo create a new plan, call `update_plan` with a short list of 1‑sentence steps (no more than 5-7 words each) with a `status` for each step (`pending`, `in_progress`, or `completed`).\n\nWhen steps have been completed, use `update_plan` to mark each finished step as `completed` and the next step you are working on as `in_progress`. There should always be exactly one `in_progress` step until everything is done. You can mark multiple items as complete in a single `update_plan` call.\n\nIf all steps are complete, ensure you call `update_plan` to mark all steps as `completed`.\n"
\ No newline at end of file
diff --git a/internal/api/translator/mime-type.go b/internal/misc/mime-type.go
similarity index 99%
rename from internal/api/translator/mime-type.go
rename to internal/misc/mime-type.go
index c467b183..dc6c9ef8 100644
--- a/internal/api/translator/mime-type.go
+++ b/internal/misc/mime-type.go
@@ -1,7 +1,7 @@
-// Package translator provides data translation and format conversion utilities
+// Package misc provides data translation and format conversion utilities
 // for the CLI Proxy API. It includes MIME type mappings and other translation
 // functions used across different API endpoints.
-package translator
+package misc
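The rename above only moves the MIME table into `misc`; the map itself is unchanged. A quick sketch of a lookup against `misc.MimeTypes`, noting that the exact key format (with or without a leading dot) is an assumption here:

```go
package main

import (
	"fmt"

	"github.com/luispater/CLIProxyAPI/internal/misc"
)

func main() {
	// Assumed key format: bare extension without the dot.
	mime, ok := misc.MimeTypes["png"]
	if !ok {
		// Fall back to a generic type when the extension is unknown.
		mime = "application/octet-stream"
	}
	fmt.Println(mime)
}
```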
diff --git a/internal/translator/codex/claude/code/codex_cc_request.go b/internal/translator/codex/claude/code/codex_cc_request.go
new file mode 100644
index 00000000..57ef6f45
--- /dev/null
+++ b/internal/translator/codex/claude/code/codex_cc_request.go
@@ -0,0 +1,114 @@
+// Package code provides request translation functionality for the Claude API.
+// It handles parsing and transforming Claude API requests into the Codex format,
+// extracting model information, system instructions, message contents, and tool declarations.
+// The package also performs JSON data cleaning and transformation to ensure compatibility
+// between the Claude API format and the Codex backend's expected format.
+package code
+
+import (
+	"fmt"
+
+	"github.com/luispater/CLIProxyAPI/internal/misc"
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
+)
+
+// ConvertClaudeCodeRequestToCodex parses and transforms a Claude API request into the Codex format.
+// It extracts the model name, system instruction, message contents, and tool declarations
+// from the raw JSON request and returns them in the JSON shape expected by the Codex backend.
+func ConvertClaudeCodeRequestToCodex(rawJSON []byte) string {
+	template := `{"model":"","instructions":"","input":[]}`
+
+	instructions := misc.CodexInstructions
+	template, _ = sjson.SetRaw(template, "instructions", instructions)
+
+	rootResult := gjson.ParseBytes(rawJSON)
+	modelResult := rootResult.Get("model")
+	template, _ = sjson.Set(template, "model", modelResult.String())
+
+	systemsResult := rootResult.Get("system")
+	if systemsResult.IsArray() {
+		systemResults := systemsResult.Array()
+		message := `{"type":"message","role":"user","content":[]}`
+		for i := 0; i < len(systemResults); i++ {
+			systemResult := systemResults[i]
+			systemTypeResult := systemResult.Get("type")
+			if systemTypeResult.String() == "text" {
+				// Append sequentially so skipped non-text entries do not leave gaps.
+				currentIndex := len(gjson.Get(message, "content").Array())
+				message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", currentIndex), "input_text")
+				message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", currentIndex), systemResult.Get("text").String())
+			}
+		}
+		template, _ = sjson.SetRaw(template, "input.-1", message)
+	}
+
+	messagesResult := rootResult.Get("messages")
+	if messagesResult.IsArray() {
+		messageResults := messagesResult.Array()
+
+		for i := 0; i < len(messageResults); i++ {
+			messageResult := messageResults[i]
+
+			messageContentsResult := messageResult.Get("content")
+			if messageContentsResult.IsArray() {
+				messageContentResults := messageContentsResult.Array()
+				for j := 0; j < len(messageContentResults); j++ {
+					messageContentResult := messageContentResults[j]
+					messageContentTypeResult := messageContentResult.Get("type")
+					if messageContentTypeResult.String() == "text" {
+						message := `{"type": "message","role":"","content":[]}`
+						messageRole := messageResult.Get("role").String()
+						message, _ = sjson.Set(message, "role", messageRole)
+
+						partType := "input_text"
+						if messageRole == "assistant" {
+							partType = "output_text"
+						}
+
+						currentIndex := len(gjson.Get(message, "content").Array())
+						message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", currentIndex), partType)
+						message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", currentIndex), messageContentResult.Get("text").String())
+						template, _ = sjson.SetRaw(template, "input.-1", message)
+					} else if messageContentTypeResult.String() == "tool_use" {
+						functionCallMessage := `{"type":"function_call"}`
+						functionCallMessage, _ = sjson.Set(functionCallMessage, "call_id", messageContentResult.Get("id").String())
+						functionCallMessage, _ = sjson.Set(functionCallMessage, "name", messageContentResult.Get("name").String())
+						functionCallMessage, _ = sjson.Set(functionCallMessage, "arguments", messageContentResult.Get("input").Raw)
+						template, _ = sjson.SetRaw(template, "input.-1", functionCallMessage)
+					} else if messageContentTypeResult.String() == "tool_result" {
+						functionCallOutputMessage := `{"type":"function_call_output"}`
+						functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
+						functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
+						template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage)
+					}
+				}
+			}
+		}
+
+	}
+
+	toolsResult := rootResult.Get("tools")
+	if toolsResult.IsArray() {
+		template, _ = sjson.SetRaw(template, "tools", `[]`)
+		template, _ = sjson.Set(template, "tool_choice", `auto`)
+		toolResults := toolsResult.Array()
+		for i := 0; i < len(toolResults); i++ {
+			toolResult := toolResults[i]
+			tool := toolResult.Raw
+			tool, _ = sjson.Set(tool, "type", "function")
+			tool, _ = sjson.SetRaw(tool, "parameters", toolResult.Get("input_schema").Raw)
+			tool, _ = sjson.Delete(tool, "input_schema")
+			tool, _ = sjson.Delete(tool, "parameters.$schema")
+			tool, _ = sjson.Set(tool, "strict", false)
+			template, _ = sjson.SetRaw(template, "tools.-1", tool)
+		}
+	}
+
+	template, _ = sjson.Set(template, "parallel_tool_calls", true)
+	template, _ = sjson.Set(template, "reasoning.effort", "low")
+	template, _ = sjson.Set(template, "reasoning.summary", "auto")
+	template, _ = sjson.Set(template, "stream", true)
+	template, _ = sjson.Set(template, "store", false)
+	template, _ = sjson.Set(template, "include", []string{"reasoning.encrypted_content"})
+
+	return template
+}
diff --git a/internal/translator/codex/claude/code/codex_cc_response.go b/internal/translator/codex/claude/code/codex_cc_response.go
new file mode 100644
index 00000000..af7cbc04
--- /dev/null
+++ b/internal/translator/codex/claude/code/codex_cc_response.go
@@ -0,0 +1,129 @@
+// Package code provides response translation functionality for the Claude API.
+// This package handles the conversion of backend client responses into Claude-compatible
+// Server-Sent Events (SSE) format, implementing a state machine that manages
+// different response types including text content, thinking processes, and function calls.
+// The translation ensures proper sequencing of SSE events and maintains state across
+// multiple response chunks to provide a seamless streaming experience.
+package code
+
+import (
+	"fmt"
+
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
+)
+
+// ConvertCodexResponseToClaude converts a single Codex streaming event into
+// Claude-compatible Server-Sent Events (SSE) output. It maps Codex event types
+// (response creation, reasoning summaries, text deltas, and function calls)
+// onto the corresponding Claude SSE events.
+//
+// The hasToolCall flag is threaded through successive calls so that the final
+// message_delta can report "tool_use" as the stop reason when any function
+// call was emitted; the updated flag is returned to the caller.
+func ConvertCodexResponseToClaude(rawJSON []byte, hasToolCall bool) (string, bool) {
+	output := ""
+	rootResult := gjson.ParseBytes(rawJSON)
+	typeResult := rootResult.Get("type")
+	typeStr := typeResult.String()
+	template := ""
+	if typeStr == "response.created" {
+		template = `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`
+		template, _ = sjson.Set(template, "message.model", rootResult.Get("response.model").String())
+		template, _ = sjson.Set(template, "message.id", rootResult.Get("response.id").String())
+
+		output = "event: message_start\n"
+		output += fmt.Sprintf("data: %s\n", template)
+	} else if typeStr == "response.reasoning_summary_part.added" {
+		template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
+		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+
+		output = "event: content_block_start\n"
+		output += fmt.Sprintf("data: %s\n", template)
+	} else if typeStr == "response.reasoning_summary_text.delta" {
+		template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
+		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+		template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String())
+
+		output = "event: content_block_delta\n"
+		output += fmt.Sprintf("data: %s\n", template)
+	} else if typeStr == "response.reasoning_summary_part.done" {
+		template = `{"type":"content_block_stop","index":0}`
+		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+
+		output = "event: content_block_stop\n"
+		output += fmt.Sprintf("data: %s\n", template)
+	} else if typeStr == "response.content_part.added" {
+		template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
+		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+
+		output = "event: content_block_start\n"
+		output += fmt.Sprintf("data: %s\n", template)
+	} else if typeStr == "response.output_text.delta" {
+		template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
+		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+		template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String())
+
+		output = "event: content_block_delta\n"
+		output += fmt.Sprintf("data: %s\n", template)
+	} else if typeStr == "response.content_part.done" {
+		template = `{"type":"content_block_stop","index":0}`
+		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+
+		output = "event: content_block_stop\n"
+		output += fmt.Sprintf("data: %s\n", template)
+	} else if typeStr == "response.completed" {
+		template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
+		if hasToolCall {
+			template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
+		} else {
+			template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
+		}
+		template, _ = sjson.Set(template, "usage.input_tokens", rootResult.Get("response.usage.input_tokens").Int())
+		template, _ = sjson.Set(template, "usage.output_tokens", rootResult.Get("response.usage.output_tokens").Int())
+
+		output = "event: message_delta\n"
+		output += fmt.Sprintf("data: %s\n\n", template)
+		output += "event: message_stop\n"
+		output += `data: {"type":"message_stop"}`
+		output += "\n\n"
+	} else if typeStr == "response.output_item.added" {
+		itemResult := rootResult.Get("item")
+		itemType := itemResult.Get("type").String()
+		if itemType == "function_call" {
+			hasToolCall = true
+			template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
+			template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+			template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
+			template, _ = sjson.Set(template, "content_block.name", itemResult.Get("name").String())
+
+			output = "event: content_block_start\n"
+			output += fmt.Sprintf("data: %s\n\n", template)
+
+			template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
+			template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+
+			output += "event: content_block_delta\n"
+			output += fmt.Sprintf("data: %s\n", template)
+		}
+	} else if typeStr == "response.output_item.done" {
+		itemResult := rootResult.Get("item")
+		itemType := itemResult.Get("type").String()
+		if itemType == "function_call" {
+			template = `{"type":"content_block_stop","index":0}`
+			template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+
+			output = "event: content_block_stop\n"
+			output += fmt.Sprintf("data: %s\n", template)
+		}
+	} else if typeStr == "response.function_call_arguments.delta" {
+		template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
+		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
+		template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
+
+		output += "event: content_block_delta\n"
+		output += fmt.Sprintf("data: %s\n", template)
+	}
+
+	return output, hasToolCall
+}
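A hedged sketch of what this mapping produces for one event type; the sample Codex event below is illustrative, but the emitted SSE shape follows the `response.output_text.delta` branch above:

```go
package main

import (
	"fmt"

	code "github.com/luispater/CLIProxyAPI/internal/translator/codex/claude/code"
)

func main() {
	// Illustrative Codex streaming event.
	event := []byte(`{"type":"response.output_text.delta","output_index":0,"delta":"Hello"}`)

	sse, hasToolCall := code.ConvertCodexResponseToClaude(event, false)
	fmt.Println(sse)
	// Expected shape, per the branch above:
	// event: content_block_delta
	// data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
	fmt.Println(hasToolCall) // false: no function call has been seen yet
}
```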
diff --git a/internal/translator/codex/gemini/codex_gemini_request.go b/internal/translator/codex/gemini/codex_gemini_request.go
new file mode 100644
index 00000000..61c395d6
--- /dev/null
+++ b/internal/translator/codex/gemini/codex_gemini_request.go
@@ -0,0 +1,199 @@
+// Package code provides request translation functionality for the Gemini API.
+// It handles parsing and transforming Gemini API requests into the Codex format,
+// extracting model information, system instructions, message contents, and tool declarations.
+// The package also performs JSON data cleaning and transformation to ensure compatibility
+// between the Gemini API format and the Codex backend's expected format.
+package code
+
+import (
+	"crypto/rand"
+	"math/big"
+	"strings"
+
+	"github.com/luispater/CLIProxyAPI/internal/misc"
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
+)
+
+// ConvertGeminiRequestToCodex parses and transforms a Gemini API request into the Codex format.
+// It extracts the model name, system instruction, message contents, and tool declarations
+// from the raw JSON request and returns them in the JSON shape expected by the Codex backend.
+func ConvertGeminiRequestToCodex(rawJSON []byte) string {
+	// Base template
+	out := `{"model":"","instructions":"","input":[]}`
+
+	// Inject standard Codex instructions
+	instructions := misc.CodexInstructions
+	out, _ = sjson.SetRaw(out, "instructions", instructions)
+
+	root := gjson.ParseBytes(rawJSON)
+
+	// Helper state for generating paired call IDs in the form call_<suffix>.
+	// Gemini uses sequential pairing across possibly multiple in-flight
+	// functionCalls, so we keep a FIFO queue of generated call IDs and
+	// consume them in order when functionResponses arrive.
+	var pendingCallIDs []string
+
+	// genCallID creates a random call id like: call_<24 chars>
+	genCallID := func() string {
+		const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+		var b strings.Builder
+		// 24-character random suffix
+		for i := 0; i < 24; i++ {
+			n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
+			b.WriteByte(letters[n.Int64()])
+		}
+		return "call_" + b.String()
+	}
+
+	// Model
+	if v := root.Get("model"); v.Exists() {
+		out, _ = sjson.Set(out, "model", v.Value())
+	}
+
+	// System instruction -> as a user message with input_text parts
+	sysParts := root.Get("system_instruction.parts")
+	if sysParts.IsArray() {
+		msg := `{"type":"message","role":"user","content":[]}`
+		arr := sysParts.Array()
+		for i := 0; i < len(arr); i++ {
+			p := arr[i]
+			if t := p.Get("text"); t.Exists() {
+				part := `{}`
+				part, _ = sjson.Set(part, "type", "input_text")
+				part, _ = sjson.Set(part, "text", t.String())
+				msg, _ = sjson.SetRaw(msg, "content.-1", part)
+			}
+		}
+		if len(gjson.Get(msg, "content").Array()) > 0 {
+			out, _ = sjson.SetRaw(out, "input.-1", msg)
+		}
+	}
+
+	// Contents -> messages and function calls/results
+	contents := root.Get("contents")
+	if contents.IsArray() {
+		items := contents.Array()
+		for i := 0; i < len(items); i++ {
+			item := items[i]
+			role := item.Get("role").String()
+			if role == "model" {
+				role = "assistant"
+			}
+
+			parts := item.Get("parts")
+			if !parts.IsArray() {
+				continue
+			}
+			parr := parts.Array()
+			for j := 0; j < len(parr); j++ {
+				p := parr[j]
+				// text part
+				if t := p.Get("text"); t.Exists() {
+					msg := `{"type":"message","role":"","content":[]}`
+					msg, _ = sjson.Set(msg, "role", role)
+					partType := "input_text"
+					if role == "assistant" {
+						partType = "output_text"
+					}
+					part := `{}`
+					part, _ = sjson.Set(part, "type", partType)
+					part, _ = sjson.Set(part, "text", t.String())
+					msg, _ = sjson.SetRaw(msg, "content.-1", part)
+					out, _ = sjson.SetRaw(out, "input.-1", msg)
+					continue
+				}
+
+				// function call from model
+				if fc := p.Get("functionCall"); fc.Exists() {
+					fn := `{"type":"function_call"}`
+					if name := fc.Get("name"); name.Exists() {
+						fn, _ = sjson.Set(fn, "name", name.String())
+					}
+					if args := fc.Get("args"); args.Exists() {
+						fn, _ = sjson.Set(fn, "arguments", args.Raw)
+					}
+					// generate a paired random call_id and enqueue it so the
+					// corresponding functionResponse can pop the earliest id
+					// to preserve ordering when multiple calls are present.
+ id := genCallID()
+ fn, _ = sjson.Set(fn, "call_id", id)
+ pendingCallIDs = append(pendingCallIDs, id)
+ out, _ = sjson.SetRaw(out, "input.-1", fn)
+ continue
+ }
+
+ // function response from user
+ if fr := p.Get("functionResponse"); fr.Exists() {
+ fno := `{"type":"function_call_output"}`
+ // Prefer a string result if present; otherwise embed the raw response as a string
+ if res := fr.Get("response.result"); res.Exists() {
+ fno, _ = sjson.Set(fno, "output", res.String())
+ } else if resp := fr.Get("response"); resp.Exists() {
+ fno, _ = sjson.Set(fno, "output", resp.Raw)
+ }
+ // attach the oldest queued call_id to pair the response
+ // with its call. If the queue is empty, generate a new id.
+ var id string
+ if len(pendingCallIDs) > 0 {
+ id = pendingCallIDs[0]
+ // pop the first element
+ pendingCallIDs = pendingCallIDs[1:]
+ } else {
+ id = genCallID()
+ }
+ fno, _ = sjson.Set(fno, "call_id", id)
+ out, _ = sjson.SetRaw(out, "input.-1", fno)
+ continue
+ }
+ }
+ }
+ }
+
+ // Tools mapping: Gemini functionDeclarations -> Codex tools
+ tools := root.Get("tools")
+ if tools.IsArray() {
+ out, _ = sjson.SetRaw(out, "tools", `[]`)
+ out, _ = sjson.Set(out, "tool_choice", "auto")
+ tarr := tools.Array()
+ for i := 0; i < len(tarr); i++ {
+ td := tarr[i]
+ fns := td.Get("functionDeclarations")
+ if !fns.IsArray() {
+ continue
+ }
+ farr := fns.Array()
+ for j := 0; j < len(farr); j++ {
+ fn := farr[j]
+ tool := `{}`
+ tool, _ = sjson.Set(tool, "type", "function")
+ if v := fn.Get("name"); v.Exists() {
+ tool, _ = sjson.Set(tool, "name", v.String())
+ }
+ if v := fn.Get("description"); v.Exists() {
+ tool, _ = sjson.Set(tool, "description", v.String())
+ }
+ if prm := fn.Get("parameters"); prm.Exists() {
+ // Remove the optional $schema field if present and disallow additional properties
+ cleaned := prm.Raw
+ cleaned, _ = sjson.Delete(cleaned, "$schema")
+ cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
+ tool, _ = sjson.SetRaw(tool, "parameters", cleaned)
+ }
+ tool, _ = sjson.Set(tool, "strict", false)
+ out, _ = sjson.SetRaw(out, "tools.-1", tool)
+ }
+ }
+ }
+
+ // Fixed flags aligning with Codex expectations
+ out, _ = sjson.Set(out, "parallel_tool_calls", true)
+ out, _ = sjson.Set(out, "reasoning.effort", "low")
+ out, _ = sjson.Set(out, "reasoning.summary", "auto")
+ out, _ = sjson.Set(out, "stream", true)
+ out, _ = sjson.Set(out, "store", false)
+ out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})
+
+ return out
+}
diff --git a/internal/translator/codex/gemini/codex_gemini_response.go b/internal/translator/codex/gemini/codex_gemini_response.go
new file mode 100644
index 00000000..8b3f1840
--- /dev/null
+++ b/internal/translator/codex/gemini/codex_gemini_response.go
@@ -0,0 +1,251 @@
+// Package code provides response translation functionality for Gemini API.
+// This package handles the conversion of Codex backend responses into Gemini-compatible
+// JSON format, transforming streaming events into single-line JSON responses that include
+// thinking content, regular text content, and function calls in the format expected by
+// Gemini API clients.
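+//
+// Rough mapping of the streaming events handled below (abbreviated, for
+// orientation only):
+//
+//	response.reasoning_summary_text.delta     -> parts[].{thought:true, text}
+//	response.output_text.delta                -> parts[].{text}
+//	response.output_item.done (function_call) -> parts[].functionCall (buffered)
+//	response.completed                        -> usageMetadata token counts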
+package code
+
+import (
+ "encoding/json"
+ "time"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+// ConvertCodexResponseToGeminiParams carries streaming translation state
+// (model, creation time, response ID, and the last buffered function-call
+// output) across successive calls.
+type ConvertCodexResponseToGeminiParams struct {
+ Model string
+ CreatedAt int64
+ ResponseID string
+ LastStorageOutput string
+}
+
+// ConvertCodexResponseToGemini converts Codex streaming response format to Gemini single-line JSON format.
+// This function processes various Codex event types and transforms them into Gemini-compatible JSON responses.
+// It handles thinking content, regular text content, and function calls, outputting single-line JSON
+// that matches the Gemini API response format.
+// The param struct carries state between calls, including the last buffered
+// function-call output, so consecutive function calls are sequenced properly.
+func ConvertCodexResponseToGemini(rawJSON []byte, param *ConvertCodexResponseToGeminiParams) []string {
+ rootResult := gjson.ParseBytes(rawJSON)
+ typeResult := rootResult.Get("type")
+ typeStr := typeResult.String()
+
+ // Base Gemini response template
+ template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`
+ if param.LastStorageOutput != "" && typeStr == "response.output_item.done" {
+ template = param.LastStorageOutput
+ } else {
+ template, _ = sjson.Set(template, "modelVersion", param.Model)
+ createdAtResult := rootResult.Get("response.created_at")
+ if createdAtResult.Exists() {
+ param.CreatedAt = createdAtResult.Int()
+ template, _ = sjson.Set(template, "createTime", time.Unix(param.CreatedAt, 0).Format(time.RFC3339Nano))
+ }
+ template, _ = sjson.Set(template, "responseId", param.ResponseID)
+ }
+
+ // Handle function call completion
+ if typeStr == "response.output_item.done" {
+ itemResult := rootResult.Get("item")
+ itemType := itemResult.Get("type").String()
+ if itemType == "function_call" {
+ // Create function call part
+ functionCall := `{"functionCall":{"name":"","args":{}}}`
+ functionCall, _ = sjson.Set(functionCall, "functionCall.name", itemResult.Get("name").String())
+
+ // Parse and set arguments
+ argsStr := itemResult.Get("arguments").String()
+ if argsStr != "" {
+ argsResult := gjson.Parse(argsStr)
+ if argsResult.IsObject() {
+ functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
+ }
+ }
+
+ template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
+ template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
+
+ param.LastStorageOutput = template
+
+ // Buffer this message and emit nothing now; it is flushed with the next event.
+ return []string{}
+ }
+ }
+
+ if typeStr == "response.created" { // Handle response creation - set model and response ID
+ template, _ = sjson.Set(template, "modelVersion", rootResult.Get("response.model").String())
+ template, _ = sjson.Set(template, "responseId", rootResult.Get("response.id").String())
+ param.ResponseID = rootResult.Get("response.id").String()
+ } else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta
+ part := `{"thought":true,"text":""}`
+ part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
+ template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
+ } else if typeStr == "response.output_text.delta" { // Handle regular text content delta
+ part := `{"text":""}`
+ part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
+ template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1",
part) + } else if typeStr == "response.completed" { // Handle response completion with usage metadata + template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int()) + template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int()) + totalTokens := rootResult.Get("response.usage.input_tokens").Int() + rootResult.Get("response.usage.output_tokens").Int() + template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens) + } else { + return []string{} + } + + if param.LastStorageOutput != "" { + return []string{param.LastStorageOutput, template} + } else { + return []string{template} + } + +} + +// ConvertCodexResponseToGeminiNonStream converts a completed Codex response to Gemini non-streaming format. +// This function processes the final response.completed event and transforms it into a complete +// Gemini-compatible JSON response that includes all content parts, function calls, and usage metadata. +func ConvertCodexResponseToGeminiNonStream(rawJSON []byte, model string) string { + rootResult := gjson.ParseBytes(rawJSON) + + // Verify this is a response.completed event + if rootResult.Get("type").String() != "response.completed" { + return "" + } + + // Base Gemini response template for non-streaming + template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}` + + // Set model version + template, _ = sjson.Set(template, "modelVersion", model) + + // Set response metadata from the completed response + responseData := rootResult.Get("response") + if responseData.Exists() { + // Set response ID + if responseId := responseData.Get("id"); responseId.Exists() { + template, _ = sjson.Set(template, "responseId", responseId.String()) + } + + // Set creation time + if createdAt := responseData.Get("created_at"); createdAt.Exists() { + template, _ = sjson.Set(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano)) + } + + // Set usage metadata + if usage := responseData.Get("usage"); usage.Exists() { + inputTokens := usage.Get("input_tokens").Int() + outputTokens := usage.Get("output_tokens").Int() + totalTokens := inputTokens + outputTokens + + template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens) + template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens) + template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens) + } + + // Process output content to build parts array + var parts []interface{} + hasToolCall := false + var pendingFunctionCalls []interface{} + + flushPendingFunctionCalls := func() { + if len(pendingFunctionCalls) > 0 { + // Add all pending function calls as individual parts + // This maintains the original Gemini API format while ensuring consecutive calls are grouped together + for _, fc := range pendingFunctionCalls { + parts = append(parts, fc) + } + pendingFunctionCalls = nil + } + } + + if output := responseData.Get("output"); output.Exists() && output.IsArray() { + output.ForEach(func(key, value gjson.Result) bool { + itemType := value.Get("type").String() + + switch itemType { + case "reasoning": + // Flush any pending function calls before adding non-function content + flushPendingFunctionCalls() + + // Add thinking content + if content := value.Get("content"); content.Exists() { + part := 
map[string]interface{}{
+ "thought": true,
+ "text": content.String(),
+ }
+ parts = append(parts, part)
+ }
+
+ case "message":
+ // Flush any pending function calls before adding non-function content
+ flushPendingFunctionCalls()
+
+ // Add regular text content
+ if content := value.Get("content"); content.Exists() && content.IsArray() {
+ content.ForEach(func(_, contentItem gjson.Result) bool {
+ if contentItem.Get("type").String() == "output_text" {
+ if text := contentItem.Get("text"); text.Exists() {
+ part := map[string]interface{}{
+ "text": text.String(),
+ }
+ parts = append(parts, part)
+ }
+ }
+ return true
+ })
+ }
+
+ case "function_call":
+ // Collect function call for potential merging with consecutive ones
+ hasToolCall = true
+ functionCall := map[string]interface{}{
+ "functionCall": map[string]interface{}{
+ "name": value.Get("name").String(),
+ "args": map[string]interface{}{},
+ },
+ }
+
+ // Parse and set arguments
+ if argsStr := value.Get("arguments").String(); argsStr != "" {
+ argsResult := gjson.Parse(argsStr)
+ if argsResult.IsObject() {
+ var args map[string]interface{}
+ if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
+ functionCall["functionCall"].(map[string]interface{})["args"] = args
+ }
+ }
+ }
+
+ pendingFunctionCalls = append(pendingFunctionCalls, functionCall)
+ }
+ return true
+ })
+
+ // Handle any remaining pending function calls at the end
+ flushPendingFunctionCalls()
+ }
+
+ // Set the parts array
+ if len(parts) > 0 {
+ template, _ = sjson.SetRaw(template, "candidates.0.content.parts", mustMarshalJSON(parts))
+ }
+
+ // Gemini reports STOP for both plain completions and tool calls, so the
+ // finish reason does not depend on hasToolCall.
+ _ = hasToolCall
+ template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
+ }
+
+ return template
+}
+
+// mustMarshalJSON marshals data to JSON, panicking on error (should not happen with valid data)
+func mustMarshalJSON(v interface{}) string {
+ data, err := json.Marshal(v)
+ if err != nil {
+ panic(err)
+ }
+ return string(data)
+}
diff --git a/internal/translator/codex/openai/codex_openai_request.go b/internal/translator/codex/openai/codex_openai_request.go
new file mode 100644
index 00000000..c03977f7
--- /dev/null
+++ b/internal/translator/codex/openai/codex_openai_request.go
@@ -0,0 +1,227 @@
+// Package openai provides utilities to translate OpenAI Chat Completions
+// request JSON into OpenAI Responses API request JSON using gjson/sjson.
+// It supports tools, multimodal text/image inputs, and Structured Outputs.
+package openai
+
+import (
+ "github.com/luispater/CLIProxyAPI/internal/misc"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+// ConvertOpenAIChatRequestToCodex converts an OpenAI Chat Completions request JSON
+// into an OpenAI Responses API request JSON. The transformation follows the
+// examples defined in docs/2.md exactly, including tools, multi-turn dialog,
+// multimodal text/image handling, and Structured Outputs mapping.
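+//
+// Minimal sketch (hypothetical, abbreviated payloads): a Chat Completions
+// request like
+//
+//	{"model":"gpt-5","stream":true,"messages":[{"role":"user","content":"hi"}]}
+//
+// becomes a Responses API body along the lines of
+//
+//	{"model":"gpt-5","stream":true,"instructions":"...","store":false,
+//	 "input":[{"role":"user","content":[{"type":"input_text","text":"hi"}]}]}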
+func ConvertOpenAIChatRequestToCodex(rawJSON []byte) string {
+ // Start with empty JSON object
+ out := `{}`
+ store := false
+
+ // Codex requires streaming; force stream to true whenever the client sends the field
+ if v := gjson.GetBytes(rawJSON, "stream"); v.Exists() {
+ out, _ = sjson.Set(out, "stream", true)
+ }
+
+ // Codex does not support temperature, top_p, top_k, or max_output_tokens, so these mappings are intentionally commented out
+ // if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() {
+ // out, _ = sjson.Set(out, "temperature", v.Value())
+ // }
+ // if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() {
+ // out, _ = sjson.Set(out, "top_p", v.Value())
+ // }
+ // if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() {
+ // out, _ = sjson.Set(out, "top_k", v.Value())
+ // }
+
+ // Token limit mapping is likewise unsupported and left commented out
+ // if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() {
+ // out, _ = sjson.Set(out, "max_output_tokens", v.Value())
+ // }
+ // if v := gjson.GetBytes(rawJSON, "max_completion_tokens"); v.Exists() {
+ // out, _ = sjson.Set(out, "max_output_tokens", v.Value())
+ // }
+
+ // Map reasoning effort
+ if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() {
+ out, _ = sjson.Set(out, "reasoning.effort", v.Value())
+ out, _ = sjson.Set(out, "reasoning.summary", "auto")
+ }
+
+ // Model
+ if v := gjson.GetBytes(rawJSON, "model"); v.Exists() {
+ out, _ = sjson.Set(out, "model", v.Value())
+ }
+
+ // Codex always receives the fixed instruction block; extracting instructions
+ // from the first system message is kept below, commented out, for reference
+ messages := gjson.GetBytes(rawJSON, "messages")
+ instructions := misc.CodexInstructions
+ out, _ = sjson.SetRaw(out, "instructions", instructions)
+ // if messages.IsArray() {
+ // arr := messages.Array()
+ // for i := 0; i < len(arr); i++ {
+ // m := arr[i]
+ // if m.Get("role").String() == "system" {
+ // c := m.Get("content")
+ // if c.Type == gjson.String {
+ // out, _ = sjson.Set(out, "instructions", c.String())
+ // } else if c.IsObject() && c.Get("type").String() == "text" {
+ // out, _ = sjson.Set(out, "instructions", c.Get("text").String())
+ // }
+ // break
+ // }
+ // }
+ // }
+
+ // Build input from messages, skipping system/tool roles
+ out, _ = sjson.SetRaw(out, "input", `[]`)
+ if messages.IsArray() {
+ arr := messages.Array()
+ for i := 0; i < len(arr); i++ {
+ m := arr[i]
+ role := m.Get("role").String()
+ if role == "tool" || role == "function" {
+ continue
+ }
+
+ // Prepare message object
+ msg := `{}`
+ if role == "system" {
+ msg, _ = sjson.Set(msg, "role", "user")
+ } else {
+ msg, _ = sjson.Set(msg, "role", role)
+ }
+
+ msg, _ = sjson.SetRaw(msg, "content", `[]`)
+
+ c := m.Get("content")
+ if c.Type == gjson.String {
+ // Single string content
+ partType := "input_text"
+ if role == "assistant" {
+ partType = "output_text"
+ }
+ part := `{}`
+ part, _ = sjson.Set(part, "type", partType)
+ part, _ = sjson.Set(part, "text", c.String())
+ msg, _ = sjson.SetRaw(msg, "content.-1", part)
+ } else if c.IsArray() {
+ items := c.Array()
+ for j := 0; j < len(items); j++ {
+ it := items[j]
+ t := it.Get("type").String()
+ switch t {
+ case "text":
+ partType := "input_text"
+ if role == "assistant" {
+ partType = "output_text"
+ }
+ part := `{}`
+ part, _ = sjson.Set(part, "type", partType)
+ part, _ = sjson.Set(part, "text", it.Get("text").String())
+ msg, _ = sjson.SetRaw(msg, "content.-1", part)
+ case "image_url":
+ // Map image inputs to input_image for Responses API
+ if role == "user" {
+ part := `{}`
+ part, _ = sjson.Set(part, "type", "input_image")
+ if u := it.Get("image_url.url"); u.Exists() {
+ part, _ = sjson.Set(part, "image_url", u.String())
+ }
+ msg,
_ = sjson.SetRaw(msg, "content.-1", part)
+ }
+ case "file":
+ // Files are not specified in examples; skip for now
+ }
+ }
+ }
+
+ out, _ = sjson.SetRaw(out, "input.-1", msg)
+ }
+ }
+
+ // Map response_format and text settings to Responses API text.format
+ rf := gjson.GetBytes(rawJSON, "response_format")
+ text := gjson.GetBytes(rawJSON, "text")
+ if rf.Exists() {
+ // Always create text object when response_format provided
+ if !gjson.Get(out, "text").Exists() {
+ out, _ = sjson.SetRaw(out, "text", `{}`)
+ }
+
+ rft := rf.Get("type").String()
+ switch rft {
+ case "text":
+ out, _ = sjson.Set(out, "text.format.type", "text")
+ case "json_schema":
+ js := rf.Get("json_schema")
+ if js.Exists() {
+ out, _ = sjson.Set(out, "text.format.type", "json_schema")
+ if v := js.Get("name"); v.Exists() {
+ out, _ = sjson.Set(out, "text.format.name", v.Value())
+ }
+ if v := js.Get("strict"); v.Exists() {
+ out, _ = sjson.Set(out, "text.format.strict", v.Value())
+ }
+ if v := js.Get("schema"); v.Exists() {
+ out, _ = sjson.SetRaw(out, "text.format.schema", v.Raw)
+ }
+ }
+ }
+
+ // Map verbosity if provided
+ if text.Exists() {
+ if v := text.Get("verbosity"); v.Exists() {
+ out, _ = sjson.Set(out, "text.verbosity", v.Value())
+ }
+ }
+
+ // The examples include store: true when response_format is provided
+ store = true
+ } else if text.Exists() {
+ // If only text.verbosity present (no response_format), map verbosity
+ if v := text.Get("verbosity"); v.Exists() {
+ if !gjson.Get(out, "text").Exists() {
+ out, _ = sjson.SetRaw(out, "text", `{}`)
+ }
+ out, _ = sjson.Set(out, "text.verbosity", v.Value())
+ }
+ }
+
+ // Map tools (flatten function fields)
+ tools := gjson.GetBytes(rawJSON, "tools")
+ if tools.IsArray() {
+ out, _ = sjson.SetRaw(out, "tools", `[]`)
+ arr := tools.Array()
+ for i := 0; i < len(arr); i++ {
+ t := arr[i]
+ if t.Get("type").String() == "function" {
+ item := `{}`
+ item, _ = sjson.Set(item, "type", "function")
+ fn := t.Get("function")
+ if fn.Exists() {
+ if v := fn.Get("name"); v.Exists() {
+ item, _ = sjson.Set(item, "name", v.Value())
+ }
+ if v := fn.Get("description"); v.Exists() {
+ item, _ = sjson.Set(item, "description", v.Value())
+ }
+ if v := fn.Get("parameters"); v.Exists() {
+ item, _ = sjson.SetRaw(item, "parameters", v.Raw)
+ }
+ if v := fn.Get("strict"); v.Exists() {
+ item, _ = sjson.Set(item, "strict", v.Value())
+ }
+ }
+ out, _ = sjson.SetRaw(out, "tools.-1", item)
+ }
+ }
+ // The examples include store: true when tools and formatting are used; be conservative
+ if rf.Exists() {
+ store = true
+ }
+ }
+
+ out, _ = sjson.Set(out, "store", store)
+ return out
+}
diff --git a/internal/translator/codex/openai/codex_openai_response.go b/internal/translator/codex/openai/codex_openai_response.go
new file mode 100644
index 00000000..b7217f94
--- /dev/null
+++ b/internal/translator/codex/openai/codex_openai_response.go
@@ -0,0 +1,231 @@
+// Package openai provides response translation functionality for converting between
+// Codex API response formats and OpenAI-compatible formats. It handles both
+// streaming and non-streaming responses, transforming backend client responses
+// into OpenAI Server-Sent Events (SSE) format and standard JSON response formats.
+// The package supports content translation, function calls, reasoning content,
+// usage metadata, and various response attributes while maintaining compatibility
+// with OpenAI API specifications.
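+//
+// Sketch of the streaming mapping implemented below (event names from the
+// Codex stream; chunk fields abbreviated):
+//
+//	response.reasoning_summary_text.delta     -> choices[0].delta.reasoning_content
+//	response.output_text.delta                -> choices[0].delta.content
+//	response.output_item.done (function_call) -> choices[0].delta.tool_calls
+//	response.completed                        -> finish_reason "stop" plus usage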
+package openai + +import ( + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +type ConvertCliToOpenAIParams struct { + ResponseID string + CreatedAt int64 + Model string +} + +// ConvertCodexResponseToOpenAIChat translates a single chunk of a streaming response from the +// Codex backend client format to the OpenAI Server-Sent Events (SSE) format. +// It returns an empty string if the chunk contains no useful data. +func ConvertCodexResponseToOpenAIChat(rawJSON []byte, params *ConvertCliToOpenAIParams) (*ConvertCliToOpenAIParams, string) { + // Initialize the OpenAI SSE template. + template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` + + rootResult := gjson.ParseBytes(rawJSON) + + typeResult := rootResult.Get("type") + dataType := typeResult.String() + if dataType == "response.created" { + return &ConvertCliToOpenAIParams{ + ResponseID: rootResult.Get("response.id").String(), + CreatedAt: rootResult.Get("response.created_at").Int(), + Model: rootResult.Get("response.model").String(), + }, "" + } + + if params == nil { + return params, "" + } + + // Extract and set the model version. + if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() { + template, _ = sjson.Set(template, "model", modelResult.String()) + } + + template, _ = sjson.Set(template, "created", params.CreatedAt) + + // Extract and set the response ID. + template, _ = sjson.Set(template, "id", params.ResponseID) + + // Extract and set usage metadata (token counts). + if usageResult := gjson.GetBytes(rawJSON, "response.usage"); usageResult.Exists() { + if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int()) + } + if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int()) + } + if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int()) + } + if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int()) + } + } + + if dataType == "response.reasoning_summary_text.delta" { + if deltaResult := rootResult.Get("delta"); deltaResult.Exists() { + template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") + template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", deltaResult.String()) + } + } else if dataType == "response.reasoning_summary_text.done" { + template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") + template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", "\n\n") + } else if dataType == "response.output_text.delta" { + if deltaResult := rootResult.Get("delta"); deltaResult.Exists() { + template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") + template, _ = sjson.Set(template, "choices.0.delta.content", deltaResult.String()) + } + } else if dataType == "response.completed" { + template, _ = sjson.Set(template, "choices.0.finish_reason", "stop") + template, _ = sjson.Set(template, "choices.0.native_finish_reason", "stop") + } else if 
dataType == "response.output_item.done" { + functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` + itemResult := rootResult.Get("item") + if itemResult.Exists() { + if itemResult.Get("type").String() != "function_call" { + return params, "" + } + template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`) + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String()) + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", itemResult.Get("name").String()) + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String()) + template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") + template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate) + } + + } else { + return params, "" + } + + return params, template +} + +// ConvertCodexResponseToOpenAIChatNonStream aggregates response from the Codex backend client +// convert a single, non-streaming OpenAI-compatible JSON response. +func ConvertCodexResponseToOpenAIChatNonStream(rawJSON string, unixTimestamp int64) string { + template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` + + // Extract and set the model version. + if modelResult := gjson.Get(rawJSON, "model"); modelResult.Exists() { + template, _ = sjson.Set(template, "model", modelResult.String()) + } + + // Extract and set the creation timestamp. + if createdAtResult := gjson.Get(rawJSON, "created_at"); createdAtResult.Exists() { + template, _ = sjson.Set(template, "created", createdAtResult.Int()) + } else { + template, _ = sjson.Set(template, "created", unixTimestamp) + } + + // Extract and set the response ID. + if idResult := gjson.Get(rawJSON, "id"); idResult.Exists() { + template, _ = sjson.Set(template, "id", idResult.String()) + } + + // Extract and set usage metadata (token counts). 
+ if usageResult := gjson.Get(rawJSON, "usage"); usageResult.Exists() { + if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int()) + } + if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int()) + } + if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int()) + } + if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() { + template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int()) + } + } + + // Process the output array for content and function calls + outputResult := gjson.Get(rawJSON, "output") + if outputResult.IsArray() { + outputArray := outputResult.Array() + var contentText string + var reasoningText string + var toolCalls []string + + for _, outputItem := range outputArray { + outputType := outputItem.Get("type").String() + + switch outputType { + case "reasoning": + // Extract reasoning content from summary + if summaryResult := outputItem.Get("summary"); summaryResult.IsArray() { + summaryArray := summaryResult.Array() + for _, summaryItem := range summaryArray { + if summaryItem.Get("type").String() == "summary_text" { + reasoningText = summaryItem.Get("text").String() + break + } + } + } + case "message": + // Extract message content + if contentResult := outputItem.Get("content"); contentResult.IsArray() { + contentArray := contentResult.Array() + for _, contentItem := range contentArray { + if contentItem.Get("type").String() == "output_text" { + contentText = contentItem.Get("text").String() + break + } + } + } + case "function_call": + // Handle function call content + functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` + + if callIdResult := outputItem.Get("call_id"); callIdResult.Exists() { + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", callIdResult.String()) + } + + if nameResult := outputItem.Get("name"); nameResult.Exists() { + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", nameResult.String()) + } + + if argsResult := outputItem.Get("arguments"); argsResult.Exists() { + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", argsResult.String()) + } + + toolCalls = append(toolCalls, functionCallTemplate) + } + } + + // Set content and reasoning content if found + if contentText != "" { + template, _ = sjson.Set(template, "choices.0.message.content", contentText) + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + } + + if reasoningText != "" { + template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningText) + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + } + + // Add tool calls if any + if len(toolCalls) > 0 { + template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`) + for _, toolCall := range toolCalls { + template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", toolCall) + } + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + } + } + + // Extract and set the finish reason based on status + if statusResult := gjson.Get(rawJSON, "status"); statusResult.Exists() { + 
status := statusResult.String() + if status == "completed" { + template, _ = sjson.Set(template, "choices.0.finish_reason", "stop") + template, _ = sjson.Set(template, "choices.0.native_finish_reason", "stop") + } + } + + return template +} diff --git a/internal/api/translator/claude/code/request.go b/internal/translator/gemini-cli/claude/code/cli_cc_request.go similarity index 96% rename from internal/api/translator/claude/code/request.go rename to internal/translator/gemini-cli/claude/code/cli_cc_request.go index 4fe924af..5b23d8a0 100644 --- a/internal/api/translator/claude/code/request.go +++ b/internal/translator/gemini-cli/claude/code/cli_cc_request.go @@ -8,16 +8,17 @@ package code import ( "bytes" "encoding/json" + "strings" + "github.com/luispater/CLIProxyAPI/internal/client" "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "strings" ) -// PrepareClaudeRequest parses and transforms a Claude API request into internal client format. +// ConvertClaudeCodeRequestToCli parses and transforms a Claude API request into internal client format. // It extracts the model name, system instruction, message contents, and tool declarations // from the raw JSON request and returns them in the format expected by the internal client. -func PrepareClaudeRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) { +func ConvertClaudeCodeRequestToCli(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) { var pathsToDelete []string root := gjson.ParseBytes(rawJSON) walk(root, "", "additionalProperties", &pathsToDelete) diff --git a/internal/api/translator/claude/code/response.go b/internal/translator/gemini-cli/claude/code/cli_cc_response.go similarity index 97% rename from internal/api/translator/claude/code/response.go rename to internal/translator/gemini-cli/claude/code/cli_cc_response.go index 3ef5fc2b..a988e8f0 100644 --- a/internal/api/translator/claude/code/response.go +++ b/internal/translator/gemini-cli/claude/code/cli_cc_response.go @@ -9,19 +9,20 @@ package code import ( "bytes" "fmt" + "time" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "time" ) -// ConvertCliToClaude performs sophisticated streaming response format conversion. +// ConvertCliResponseToClaudeCode performs sophisticated streaming response format conversion. // This function implements a complex state machine that translates backend client responses // into Claude-compatible Server-Sent Events (SSE) format. It manages different response types // and handles state transitions between content blocks, thinking processes, and function calls. // // Response type states: 0=none, 1=content, 2=thinking, 3=function // The function maintains state across multiple calls to ensure proper SSE event sequencing. 
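+// For example (hypothetical sequence): if a plain text delta arrives while
+// *responseType is 2 (thinking), the translator first emits content_block_stop
+// for the thinking block and content_block_start for a new text block before
+// forwarding the text delta itself.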
-func ConvertCliToClaude(rawJSON []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string { +func ConvertCliResponseToClaudeCode(rawJSON []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string { // Normalize the response format for different API key types // Generative Language API keys have a different response structure if isGlAPIKey { diff --git a/internal/api/translator/gemini/cli/request.go b/internal/translator/gemini-cli/gemini/cli/cli_cli_request.go similarity index 99% rename from internal/api/translator/gemini/cli/request.go rename to internal/translator/gemini-cli/gemini/cli/cli_cli_request.go index 460820d0..04b44107 100644 --- a/internal/api/translator/gemini/cli/request.go +++ b/internal/translator/gemini-cli/gemini/cli/cli_cli_request.go @@ -9,6 +9,7 @@ package cli import ( "encoding/json" "fmt" + log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" diff --git a/internal/api/translator/openai/request.go b/internal/translator/gemini-cli/openai/cli_openai_request.go similarity index 90% rename from internal/api/translator/openai/request.go rename to internal/translator/gemini-cli/openai/cli_openai_request.go index 7251ca9e..fd7fe136 100644 --- a/internal/api/translator/openai/request.go +++ b/internal/translator/gemini-cli/openai/cli_openai_request.go @@ -6,18 +6,31 @@ package openai import ( "encoding/json" - "github.com/luispater/CLIProxyAPI/internal/api/translator" "strings" "github.com/luispater/CLIProxyAPI/internal/client" + "github.com/luispater/CLIProxyAPI/internal/misc" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) -// PrepareRequest translates a raw JSON request from an OpenAI-compatible format +// ConvertOpenAIChatRequestToCli translates a raw JSON request from an OpenAI-compatible format // to the internal format expected by the backend client. It parses messages, // roles, content types (text, image, file), and tool calls. -func PrepareRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) { +// +// This function handles the complex task of converting between the OpenAI message +// format and the internal format used by the Gemini client. It processes different +// message types (system, user, assistant, tool) and content types (text, images, files). +// +// Parameters: +// - rawJSON: The raw JSON bytes of the OpenAI-compatible request +// +// Returns: +// - string: The model name to use +// - *client.Content: System instruction content (if any) +// - []client.Content: The conversation contents in internal format +// - []client.ToolDeclaration: Tool declarations from the request +func ConvertOpenAIChatRequestToCli(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) { // Extract the model name from the request, defaulting to "gemini-2.5-pro". 
modelName := "gemini-2.5-pro" modelResult := gjson.GetBytes(rawJSON, "model") @@ -126,7 +139,7 @@ func PrepareRequest(rawJSON []byte) (string, *client.Content, []client.Content, if split := strings.Split(filename, "."); len(split) > 1 { ext = split[len(split)-1] } - if mimeType, ok := translator.MimeTypes[ext]; ok { + if mimeType, ok := misc.MimeTypes[ext]; ok { parts = append(parts, client.Part{InlineData: &client.InlineData{ MimeType: mimeType, Data: fileData, diff --git a/internal/api/translator/openai/response.go b/internal/translator/gemini-cli/openai/cli_openai_response.go similarity index 96% rename from internal/api/translator/openai/response.go rename to internal/translator/gemini-cli/openai/cli_openai_response.go index 67757e29..ff8056c6 100644 --- a/internal/api/translator/openai/response.go +++ b/internal/translator/gemini-cli/openai/cli_openai_response.go @@ -15,10 +15,10 @@ import ( "github.com/tidwall/sjson" ) -// ConvertCliToOpenAI translates a single chunk of a streaming response from the +// ConvertCliResponseToOpenAIChat translates a single chunk of a streaming response from the // backend client format to the OpenAI Server-Sent Events (SSE) format. // It returns an empty string if the chunk contains no useful data. -func ConvertCliToOpenAI(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string { +func ConvertCliResponseToOpenAIChat(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string { if isGlAPIKey { rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON) } @@ -109,9 +109,9 @@ func ConvertCliToOpenAI(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) st return template } -// ConvertCliToOpenAINonStream aggregates response from the backend client +// ConvertCliResponseToOpenAIChatNonStream aggregates response from the backend client // convert a single, non-streaming OpenAI-compatible JSON response. -func ConvertCliToOpenAINonStream(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string { +func ConvertCliResponseToOpenAIChatNonStream(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string { if isGlAPIKey { rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON) } diff --git a/internal/util/provider.go b/internal/util/provider.go new file mode 100644 index 00000000..1670e80c --- /dev/null +++ b/internal/util/provider.go @@ -0,0 +1,24 @@ +// Package util provides utility functions used across the CLIProxyAPI application. +// These functions handle common tasks such as determining AI service providers +// from model names and managing HTTP proxies. +package util + +import ( + "strings" +) + +// GetProviderName determines the AI service provider based on the model name. +// It analyzes the model name string to identify which service provider it belongs to. 
+//
+// Supported providers:
+//   - "gemini" for Google's Gemini models
+//   - "gpt" for OpenAI's GPT models
+//   - "unknown" for unrecognized model names
+func GetProviderName(modelName string) string {
+ if strings.Contains(modelName, "gemini") {
+ return "gemini"
+ } else if strings.Contains(modelName, "gpt") {
+ return "gpt"
+ }
+ return "unknown"
+}
diff --git a/internal/util/proxy.go b/internal/util/proxy.go
index dbf80a02..a0a66006 100644
--- a/internal/util/proxy.go
+++ b/internal/util/proxy.go
@@ -5,17 +5,19 @@ package util
 import (
 "context"
- "github.com/luispater/CLIProxyAPI/internal/config"
- "golang.org/x/net/proxy"
 "net"
 "net/http"
 "net/url"
+
+ "github.com/luispater/CLIProxyAPI/internal/config"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/net/proxy"
 )
 
 // SetProxy configures the provided HTTP client with proxy settings from the configuration.
 // It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
 // to route requests through the configured proxy server.
-func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) {
+func SetProxy(cfg *config.Config, httpClient *http.Client) *http.Client {
 var transport *http.Transport
 proxyURL, errParse := url.Parse(cfg.ProxyURL)
 if errParse == nil {
@@ -25,7 +27,8 @@
 proxyAuth := &proxy.Auth{User: username, Password: password}
 dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
 if errSOCKS5 != nil {
- return nil, errSOCKS5
+ log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
+ return httpClient
 }
 transport = &http.Transport{
 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -39,5 +42,5 @@
 if transport != nil {
 httpClient.Transport = transport
 }
- return httpClient, nil
+ return httpClient
 }
diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go
index 68240140..e7fe2b4e 100644
--- a/internal/watcher/watcher.go
+++ b/internal/watcher/watcher.go
@@ -7,12 +7,6 @@ package watcher
 import (
 "context"
 "encoding/json"
- "github.com/fsnotify/fsnotify"
- "github.com/luispater/CLIProxyAPI/internal/auth"
- "github.com/luispater/CLIProxyAPI/internal/client"
- "github.com/luispater/CLIProxyAPI/internal/config"
- "github.com/luispater/CLIProxyAPI/internal/util"
- log "github.com/sirupsen/logrus"
 "io/fs"
 "net/http"
 "os"
@@ -20,6 +14,15 @@ import (
 "strings"
 "sync"
 "time"
+
+ "github.com/fsnotify/fsnotify"
+ "github.com/luispater/CLIProxyAPI/internal/auth/codex"
+ "github.com/luispater/CLIProxyAPI/internal/auth/gemini"
+ "github.com/luispater/CLIProxyAPI/internal/client"
+ "github.com/luispater/CLIProxyAPI/internal/config"
+ "github.com/luispater/CLIProxyAPI/internal/util"
+ log "github.com/sirupsen/logrus"
+ "github.com/tidwall/gjson"
 )
 
 // Watcher manages file watching for configuration and authentication files
@@ -27,14 +30,14 @@
 type Watcher struct {
 configPath string
 authDir string
 config *config.Config
- clients []*client.Client
+ clients []client.Client
 clientsMutex sync.RWMutex
- reloadCallback func([]*client.Client, *config.Config)
+ reloadCallback func([]client.Client, *config.Config)
 watcher *fsnotify.Watcher
 }
 
 // NewWatcher creates a new file watcher instance
-func NewWatcher(configPath, authDir string, reloadCallback func([]*client.Client, *config.Config)) (*Watcher, error) {
+func NewWatcher(configPath, authDir string, reloadCallback 
func([]client.Client, *config.Config)) (*Watcher, error) { watcher, errNewWatcher := fsnotify.NewWatcher() if errNewWatcher != nil { return nil, errNewWatcher @@ -83,7 +86,7 @@ func (w *Watcher) SetConfig(cfg *config.Config) { } // SetClients updates the current client list -func (w *Watcher) SetClients(clients []*client.Client) { +func (w *Watcher) SetClients(clients []client.Client) { w.clientsMutex.Lock() defer w.clientsMutex.Unlock() w.clients = clients @@ -193,7 +196,7 @@ func (w *Watcher) reloadClients() { log.Debugf("scanning auth directory: %s", cfg.AuthDir) // Create new client list - newClients := make([]*client.Client, 0) + newClients := make([]client.Client, 0) authFileCount := 0 successfulAuthCount := 0 @@ -209,37 +212,57 @@ func (w *Watcher) reloadClients() { authFileCount++ log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path)) - f, errOpen := os.Open(path) - if errOpen != nil { - log.Errorf("failed to open token file %s: %v", path, errOpen) - return nil // Continue processing other files + data, errReadFile := os.ReadFile(path) + if errReadFile != nil { + return errReadFile + } + + tokenType := "gemini" + typeResult := gjson.GetBytes(data, "type") + if typeResult.Exists() { + tokenType = typeResult.String() } - defer func() { - errClose := f.Close() - if errClose != nil { - log.Errorf("failed to close token file %s: %v", path, errClose) - } - }() // Decode the token storage file - var ts auth.TokenStorage - if errDecode := json.NewDecoder(f).Decode(&ts); errDecode == nil { - // For each valid token, create an authenticated client - clientCtx := context.Background() - log.Debugf(" initializing authentication for token from %s...", filepath.Base(path)) - httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg) - if errGetClient != nil { - log.Errorf(" failed to get authenticated client for token %s: %v", path, errGetClient) - return nil // Continue processing other files - } - log.Debugf(" authentication successful for token from %s", filepath.Base(path)) + if tokenType == "gemini" { + var ts gemini.GeminiTokenStorage + if err = json.Unmarshal(data, &ts); err == nil { + // For each valid token, create an authenticated client + clientCtx := context.Background() + log.Debugf(" initializing gemini authentication for token from %s...", filepath.Base(path)) + geminiAuth := gemini.NewGeminiAuth() + httpClient, errGetClient := geminiAuth.GetAuthenticatedClient(clientCtx, &ts, cfg) + if errGetClient != nil { + log.Errorf(" failed to get authenticated client for token %s: %v", path, errGetClient) + return nil // Continue processing other files + } + log.Debugf(" authentication successful for token from %s", filepath.Base(path)) - // Add the new client to the pool - cliClient := client.NewClient(httpClient, &ts, cfg) - newClients = append(newClients, cliClient) - successfulAuthCount++ - } else { - log.Errorf(" failed to decode token file %s: %v", path, errDecode) + // Add the new client to the pool + cliClient := client.NewGeminiClient(httpClient, &ts, cfg) + newClients = append(newClients, cliClient) + successfulAuthCount++ + } else { + log.Errorf(" failed to decode token file %s: %v", path, err) + } + } else if tokenType == "codex" { + var ts codex.CodexTokenStorage + if err = json.Unmarshal(data, &ts); err == nil { + // For each valid token, create an authenticated client + log.Debugf(" initializing codex authentication for token from %s...", filepath.Base(path)) + codexClient, errGetClient := client.NewCodexClient(cfg, &ts) + if 
errGetClient != nil { + log.Errorf(" failed to get authenticated client for token %s: %v", path, errGetClient) + return nil // Continue processing other files + } + log.Debugf(" authentication successful for token from %s", filepath.Base(path)) + + // Add the new client to the pool + newClients = append(newClients, codexClient) + successfulAuthCount++ + } else { + log.Errorf(" failed to decode token file %s: %v", path, err) + } } } return nil @@ -256,14 +279,10 @@ func (w *Watcher) reloadClients() { if len(cfg.GlAPIKey) > 0 { log.Debugf("processing %d Generative Language API keys", len(cfg.GlAPIKey)) for i := 0; i < len(cfg.GlAPIKey); i++ { - httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{}) - if errSetProxy != nil { - log.Errorf("set proxy failed for GL API key %d: %v", i+1, errSetProxy) - continue - } + httpClient := util.SetProxy(cfg, &http.Client{}) log.Debugf(" initializing with Generative Language API key %d...", i+1) - cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i]) + cliClient := client.NewGeminiClient(httpClient, nil, cfg, cfg.GlAPIKey[i]) newClients = append(newClients, cliClient) glAPIKeyCount++ }