From aa2f37d54dfa33be2f9a9d8b181e853d58b5582a Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Thu, 21 Aug 2025 05:11:21 +0800 Subject: [PATCH] Add Qwen support --- README.md | 30 +- README_CN.md | 27 +- cmd/server/main.go | 4 + internal/api/handlers/claude/code_handlers.go | 152 ++++++- .../api/handlers/gemini/cli/cli_handlers.go | 182 ++++++++ .../api/handlers/gemini/gemini_handlers.go | 168 ++++++++ internal/api/handlers/handlers.go | 8 + .../api/handlers/openai/openai_handlers.go | 159 +++++++ internal/api/server.go | 2 +- internal/auth/qwen/qwen_auth.go | 337 +++++++++++++++ internal/auth/qwen/qwen_token.go | 61 +++ internal/client/qwen_client.go | 288 +++++++++++++ internal/cmd/qwen_login.go | 85 ++++ internal/cmd/run.go | 23 +- .../openai/claude/openai_claude_request.go | 253 ++++++++++++ .../openai/claude/openai_claude_response.go | 389 ++++++++++++++++++ .../openai/gemini/openai_gemini_request.go | 359 ++++++++++++++++ .../openai/gemini/openai_gemini_response.go | 353 ++++++++++++++++ internal/util/provider.go | 2 + internal/watcher/watcher.go | 15 + 20 files changed, 2888 insertions(+), 9 deletions(-) create mode 100644 internal/auth/qwen/qwen_auth.go create mode 100644 internal/auth/qwen/qwen_token.go create mode 100644 internal/client/qwen_client.go create mode 100644 internal/cmd/qwen_login.go create mode 100644 internal/translator/openai/claude/openai_claude_request.go create mode 100644 internal/translator/openai/claude/openai_claude_response.go create mode 100644 internal/translator/openai/gemini/openai_gemini_request.go create mode 100644 internal/translator/openai/gemini/openai_gemini_response.go diff --git a/README.md b/README.md index fd49930a..771b2b89 100644 --- a/README.md +++ b/README.md @@ -8,19 +8,23 @@ It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth. so you can use local or multi‑account CLI access with OpenAI‑compatible clients and SDKs. +Now, We added the first Chinese provider: [Qwen Code](https://github.com/QwenLM/qwen-code). + ## Features - OpenAI/Gemini/Claude compatible API endpoints for CLI models - OpenAI Codex support (GPT models) via OAuth login - Claude Code support via OAuth login +- Qwen Code support via OAuth login - Streaming and non-streaming responses - Function calling/tools support - Multimodal input support (text and images) -- Multiple accounts with round‑robin load balancing (Gemini, OpenAI, and Claude) -- Simple CLI authentication flows (Gemini, OpenAI, and Claude) +- Multiple accounts with round‑robin load balancing (Gemini, OpenAI, Claude and Qwen) +- Simple CLI authentication flows (Gemini, OpenAI, Claude and Qwen) - Generative Language API Key support - Gemini CLI multi‑account load balancing - Claude Code multi‑account load balancing +- Qwen Code multi‑account load balancing ## Installation @@ -30,6 +34,7 @@ so you can use local or multi‑account CLI access with OpenAI‑compatible clie - A Google account with access to Gemini CLI models (optional) - An OpenAI account for Codex/GPT access (optional) - An Anthropic account for Claude Code access (optional) +- A Qwen Chat account for Qwen Code access (optional) ### Building from Source @@ -60,6 +65,8 @@ You can authenticate for Gemini, OpenAI, and/or Claude. All can coexist in the s ``` The local OAuth callback uses port `8085`. + Options: add `--no-browser` to print the login URL instead of opening a browser. The local OAuth callback uses port `1455`. 
+ - OpenAI (Codex/GPT via OAuth): ```bash ./cli-proxy-api --codex-login @@ -72,6 +79,13 @@ You can authenticate for Gemini, OpenAI, and/or Claude. All can coexist in the s ``` Options: add `--no-browser` to print the login URL instead of opening a browser. The local OAuth callback uses port `54545`. +- Qwen (Qwen Chat via OAuth): + ```bash + ./cli-proxy-api --qwen-login + ``` + Options: add `--no-browser` to print the login URL instead of opening a browser. Use the Qwen Chat's OAuth device flow. + + ### Starting the Server Once authenticated, start the server: @@ -112,7 +126,7 @@ Request body example: ``` Notes: -- Use a `gemini-*` model for Gemini (e.g., `gemini-2.5-pro`), a `gpt-*` model for OpenAI (e.g., `gpt-5`), or a `claude-*` model for Claude (e.g., `claude-3-5-sonnet-20241022`). The proxy will route to the correct provider automatically. +- Use a `gemini-*` model for Gemini (e.g., `gemini-2.5-pro`), a `gpt-*` model for OpenAI (e.g., `gpt-5`), a `claude-*` model for Claude (e.g., `claude-3-5-sonnet-20241022`), or a `qwen-*` model for Qwen (e.g., `qwen3-coder-plus`). The proxy will route to the correct provider automatically. #### Claude Messages (SSE-compatible) @@ -210,6 +224,8 @@ console.log(await claudeResponse.json()); - claude-sonnet-4-20250514 - claude-3-7-sonnet-20250219 - claude-3-5-haiku-20241022 +- qwen3-coder-plus +- qwen3-coder-flash - Gemini models auto‑switch to preview variants when needed ## Configuration @@ -338,6 +354,14 @@ export ANTHROPIC_MODEL=claude-sonnet-4-20250514 export ANTHROPIC_SMALL_FAST_MODEL=claude-3-5-haiku-20241022 ``` +Using Claude models: +```bash +export ANTHROPIC_BASE_URL=http://127.0.0.1:8317 +export ANTHROPIC_AUTH_TOKEN=sk-dummy +export ANTHROPIC_MODEL=qwen3-coder-plus +export ANTHROPIC_SMALL_FAST_MODEL=qwen3-coder-flash +``` + ## Run with Docker Run the following command to login (Gemini OAuth on port 8085): diff --git a/README_CN.md b/README_CN.md index 18885780..c10123bf 100644 --- a/README_CN.md +++ b/README_CN.md @@ -8,19 +8,23 @@ 可与本地或多账户方式配合,使用任何 OpenAI 兼容的客户端与 SDK。 +现在,我们添加了第一个中国提供商:[Qwen Code](https://github.com/QwenLM/qwen-code)。 + ## 功能特性 - 为 CLI 模型提供 OpenAI/Gemini/Claude 兼容的 API 端点 - 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录) - 新增 Claude Code 支持(OAuth 登录) +- 新增 Qwen Code 支持(OAuth 登录) - 支持流式与非流式响应 - 函数调用/工具支持 - 多模态输入(文本、图片) -- 多账户支持与轮询负载均衡(Gemini、OpenAI 与 Claude) -- 简单的 CLI 身份验证流程(Gemini、OpenAI 与 Claude) +- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude 与 Qwen) +- 简单的 CLI 身份验证流程(Gemini、OpenAI、Claude 与 Qwen) - 支持 Gemini AIStudio API 密钥 - 支持 Gemini CLI 多账户轮询 - 支持 Claude Code 多账户轮询 +- 支持 Qwen Code 多账户轮询 ## 安装 @@ -30,6 +34,7 @@ - 有权访问 Gemini CLI 模型的 Google 账户(可选) - 有权访问 OpenAI Codex/GPT 的 OpenAI 账户(可选) - 有权访问 Claude Code 的 Anthropic 账户(可选) +- 有权访问 Qwen Code 的 Qwen Chat 账户(可选) ### 从源码构建 @@ -72,6 +77,12 @@ ``` 选项:加上 `--no-browser` 可打印登录地址而不自动打开浏览器。本地 OAuth 回调端口为 `54545`。 +- Qwen(Qwen Chat,OAuth): + ```bash + ./cli-proxy-api --qwen-login + ``` + 选项:加上 `--no-browser` 可打印登录地址而不自动打开浏览器。使用 Qwen Chat 的 OAuth 设备登录流程。 + ### 启动服务器 身份验证完成后,启动服务器: @@ -112,7 +123,7 @@ POST http://localhost:8317/v1/chat/completions ``` 说明: -- 使用 `gemini-*` 模型(如 `gemini-2.5-pro`)走 Gemini,使用 `gpt-*` 模型(如 `gpt-5`)走 OpenAI,使用 `claude-*` 模型(如 `claude-3-5-sonnet-20241022`)走 Claude,服务会自动路由到对应提供商。 +- 使用 `gemini-*` 模型(如 `gemini-2.5-pro`)走 Gemini,使用 `gpt-*` 模型(如 `gpt-5`)走 OpenAI,使用 `claude-*` 模型(如 `claude-3-5-sonnet-20241022`)走 Claude,使用 `qwen-*` 模型(如 `qwen3-coder-plus`)走 Qwen,服务会自动路由到对应提供商。 #### Claude 消息(SSE 兼容) @@ -210,6 +221,8 @@ console.log(await claudeResponse.json()); - 
claude-sonnet-4-20250514 - claude-3-7-sonnet-20250219 - claude-3-5-haiku-20241022 +- qwen3-coder-plus +- qwen3-coder-flash - Gemini 模型在需要时自动切换到对应的 preview 版本 ## 配置 @@ -338,6 +351,14 @@ export ANTHROPIC_MODEL=claude-sonnet-4-20250514 export ANTHROPIC_SMALL_FAST_MODEL=claude-3-5-haiku-20241022 ``` +使用 Qwen 模型: +```bash +export ANTHROPIC_BASE_URL=http://127.0.0.1:8317 +export ANTHROPIC_AUTH_TOKEN=sk-dummy +export ANTHROPIC_MODEL=qwen3-coder-plus +export ANTHROPIC_SMALL_FAST_MODEL=qwen3-coder-flash +``` + ## 使用 Docker 运行 diff --git a/cmd/server/main.go b/cmd/server/main.go index 22e03f9e..ec700aeb 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -60,6 +60,7 @@ func main() { var login bool var codexLogin bool var claudeLogin bool + var qwenLogin bool var noBrowser bool var projectID string var configPath string @@ -68,6 +69,7 @@ func main() { flag.BoolVar(&login, "login", false, "Login Google Account") flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth") flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth") + flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth") flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", "", "Configure File Path") @@ -132,6 +134,8 @@ func main() { } else if claudeLogin { // Handle Claude login cmd.DoClaudeLogin(cfg, options) + } else if qwenLogin { + cmd.DoQwenLogin(cfg, options) } else { // Start the main proxy service cmd.StartService(cfg, configFilePath) diff --git a/internal/api/handlers/claude/code_handlers.go b/internal/api/handlers/claude/code_handlers.go index 797ef243..2124e0c5 100644 --- a/internal/api/handlers/claude/code_handlers.go +++ b/internal/api/handlers/claude/code_handlers.go @@ -11,6 +11,7 @@ import ( "context" "fmt" "net/http" + "strings" "time" "github.com/gin-gonic/gin" @@ -18,6 +19,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/client" translatorClaudeCodeToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/claude/code" translatorClaudeCodeToGeminiCli "github.com/luispater/CLIProxyAPI/internal/translator/gemini-cli/claude/code" + translatorClaudeCodeToQwen "github.com/luispater/CLIProxyAPI/internal/translator/openai/claude" "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -62,7 +64,7 @@ func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) { // Check if the client requested a streaming response. streamResult := gjson.GetBytes(rawJSON, "stream") - if streamResult.Type == gjson.False { + if !streamResult.Exists() || streamResult.Type == gjson.False { return } @@ -72,6 +74,8 @@ func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) { h.handleCodexStreamingResponse(c, rawJSON) } else if provider == "claude" { h.handleClaudeStreamingResponse(c, rawJSON) + } else if provider == "qwen" { + h.handleQwenStreamingResponse(c, rawJSON) } else { h.handleGeminiStreamingResponse(c, rawJSON) } @@ -518,3 +522,149 @@ outLoop: } } } + +// handleQwenStreamingResponse streams Claude-compatible responses backed by OpenAI. +// It converts the Claude request into Qwen responses format, establishes SSE, +// and translates streaming chunks back into Claude Code events. 
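+// The Anthropic-style request body is rewritten into an OpenAI-compatible
+// payload, sent through a Qwen client picked from the pool, and every "data: "
+// chunk coming back is translated into Claude Code SSE events before flushing.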
+func (h *ClaudeCodeAPIHandlers) handleQwenStreamingResponse(c *gin.Context, rawJSON []byte) { + // Set up Server-Sent Events (SSE) headers for streaming response + // These headers are essential for maintaining a persistent connection + // and enabling real-time streaming of chat completions + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + + // Get the http.Flusher interface to manually flush the response. + // This is crucial for streaming as it allows immediate sending of data chunks + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + // Parse and prepare the Claude request, extracting model name, system instructions, + // conversation contents, and available tools from the raw JSON + newRequestJSON := translatorClaudeCodeToQwen.ConvertAnthropicRequestToOpenAI(rawJSON) + modelName := gjson.GetBytes(rawJSON, "model").String() + + newRequestJSON, _ = sjson.Set(newRequestJSON, "model", modelName) + // log.Debugf(string(rawJSON)) + // log.Debugf(newRequestJSON) + // return + // Create a cancellable context for the backend client request + // This allows proper cleanup and cancellation of ongoing requests + cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background()) + + var cliClient client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. + // This prevents deadlocks and ensures proper resource cleanup + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + + // Main client rotation loop with quota management + // This loop implements a sophisticated load balancing and failover mechanism +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + flusher.Flush() + cliCancel() + return + } + + log.Debugf("Request use qwen account: %s", cliClient.GetEmail()) + + // Initiate streaming communication with the backend client + // This returns two channels: one for response chunks and one for errors + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "") + + // Track response state for proper Claude format conversion + + params := &translatorClaudeCodeToQwen.ConvertOpenAIResponseToAnthropicParams{ + MessageID: "", + Model: "", + CreatedAt: 0, + ContentAccumulator: strings.Builder{}, + ToolCallsAccumulator: nil, + } + + // Main streaming loop - handles multiple concurrent events using Go channels + // This select statement manages four different types of events simultaneously + for { + select { + // Case 1: Handle client disconnection + // Detects when the HTTP client has disconnected and cleans up resources + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request to prevent resource leaks + return + } + + // Case 2: Process incoming response chunks from the backend + // This handles the actual streaming data from the AI model + case chunk, okStream := <-respChan: + if !okStream { + flusher.Flush() + cliCancel() + return + } + + h.AddAPIResponseData(c, chunk) + 
h.AddAPIResponseData(c, []byte("\n")) + + // Convert the backend response to Claude-compatible format + // This translation layer ensures API compatibility + if bytes.HasPrefix(chunk, []byte("data: ")) { + jsonData := chunk[6:] + outputs := translatorClaudeCodeToQwen.ConvertOpenAIResponseToAnthropic(jsonData, params) + if len(outputs) > 0 { + for i := 0; i < len(outputs); i++ { + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write([]byte(outputs[i])) + } + } + flusher.Flush() // Immediately send the chunk to the client + // hasFirstResponse = true + } else { + // log.Debugf("chunk: %s", string(chunk)) + } + // Case 3: Handle errors from the backend + // This manages various error conditions and implements retry logic + case errInfo, okError := <-errChan: + if okError { + // log.Debugf("Code: %d, Error: %v", errInfo.StatusCode, errInfo.Error) + // Special handling for quota exceeded errors + // If configured, attempt to switch to a different project/client + if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + log.Debugf("quota exceeded, switch client") + continue outLoop // Restart the client selection process + } else { + // Forward other errors directly to the client + c.Status(errInfo.StatusCode) + _, _ = fmt.Fprint(c.Writer, errInfo.Error.Error()) + flusher.Flush() + cliCancel(errInfo.Error) + } + return + } + + // Case 4: Send periodic keep-alive signals + // Prevents connection timeouts during long-running requests + case <-time.After(3000 * time.Millisecond): + } + } + } +} diff --git a/internal/api/handlers/gemini/cli/cli_handlers.go b/internal/api/handlers/gemini/cli/cli_handlers.go index c60992f8..55c7d38c 100644 --- a/internal/api/handlers/gemini/cli/cli_handlers.go +++ b/internal/api/handlers/gemini/cli/cli_handlers.go @@ -18,6 +18,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/client" translatorGeminiToClaude "github.com/luispater/CLIProxyAPI/internal/translator/claude/gemini" translatorGeminiToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/gemini" + translatorGeminiToQwen "github.com/luispater/CLIProxyAPI/internal/translator/openai/gemini" "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -64,6 +65,8 @@ func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) { h.handleCodexInternalGenerateContent(c, rawJSON) } else if provider == "claude" { h.handleClaudeInternalGenerateContent(c, rawJSON) + } else if provider == "qwen" { + h.handleQwenInternalGenerateContent(c, rawJSON) } } else if requestRawURI == "/v1internal:streamGenerateContent" { if provider == "gemini" || provider == "unknow" { @@ -72,6 +75,8 @@ func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) { h.handleCodexInternalStreamGenerateContent(c, rawJSON) } else if provider == "claude" { h.handleClaudeInternalStreamGenerateContent(c, rawJSON) + } else if provider == "qwen" { + h.handleQwenInternalStreamGenerateContent(c, rawJSON) } } else { reqBody := bytes.NewBuffer(rawJSON) @@ -733,3 +738,180 @@ outLoop: } } } + +func (h *GeminiCLIAPIHandlers) handleQwenInternalStreamGenerateContent(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + + // Get the http.Flusher interface to manually flush the response. 
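+	// Streaming requires flushing each SSE chunk as soon as it is written.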
+ flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + modelResult := gjson.GetBytes(rawJSON, "model") + rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw) + rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String()) + rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw)) + rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction") + + // Prepare the request for the backend client. + newRequestJSON := translatorGeminiToQwen.ConvertGeminiRequestToOpenAI(rawJSON) + newRequestJSON, _ = sjson.Set(newRequestJSON, "stream", true) + + // log.Debugf("Request: %s", string(rawJSON)) + // return + + modelName := gjson.GetBytes(rawJSON, "model") + + cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background()) + + var cliClient client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName.String()) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + flusher.Flush() + cliCancel() + return + } + + log.Debugf("Request qwen use account: %s", cliClient.(*client.QwenClient).GetEmail()) + + // Send the message and receive response chunks and errors via channels. + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "") + + params := &translatorGeminiToQwen.ConvertOpenAIResponseToGeminiParams{ + ToolCallsAccumulator: nil, + ContentAccumulator: strings.Builder{}, + IsFirstChunk: false, + } + for { + select { + // Handle client disconnection. + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. + case chunk, okStream := <-respChan: + if !okStream { + cliCancel() + return + } + + h.AddAPIResponseData(c, chunk) + h.AddAPIResponseData(c, []byte("\n\n")) + + if bytes.HasPrefix(chunk, []byte("data: ")) { + jsonData := chunk[6:] + // log.Debugf(string(jsonData)) + outputs := translatorGeminiToQwen.ConvertOpenAIResponseToGemini(jsonData, params) + if len(outputs) > 0 { + for i := 0; i < len(outputs); i++ { + outputs[i], _ = sjson.SetRaw("{}", "response", outputs[i]) + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write([]byte(outputs[i])) + _, _ = c.Writer.Write([]byte("\n\n")) + } + } + // log.Debugf(string(jsonData)) + } + flusher.Flush() + // Handle errors from the backend. + case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue outLoop + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel(err.Error) + } + return + } + // Send a keep-alive signal to the client. 
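+			// No data is written here; the timeout only re-enters the select so the
+			// loop does not block indefinitely while the backend is idle.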
+ case <-time.After(500 * time.Millisecond): + } + } + } +} + +func (h *GeminiCLIAPIHandlers) handleQwenInternalGenerateContent(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "application/json") + + modelResult := gjson.GetBytes(rawJSON, "model") + rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw) + rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String()) + rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw)) + rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction") + + // Prepare the request for the backend client. + newRequestJSON := translatorGeminiToQwen.ConvertGeminiRequestToOpenAI(rawJSON) + // log.Debugf("Request: %s", newRequestJSON) + + modelName := gjson.GetBytes(rawJSON, "model") + + cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background()) + + var cliClient client.Client + defer func() { + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName.String()) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + cliCancel() + return + } + + log.Debugf("Request use qwen account: %s", cliClient.GetEmail()) + + resp, err := cliClient.SendRawMessage(cliCtx, []byte(newRequestJSON), "") + if err != nil { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue + } else { + c.Status(err.StatusCode) + _, _ = c.Writer.Write([]byte(err.Error.Error())) + cliCancel(err.Error) + } + break + } else { + h.AddAPIResponseData(c, resp) + h.AddAPIResponseData(c, []byte("\n")) + + newResp := translatorGeminiToQwen.ConvertOpenAINonStreamResponseToGemini(resp) + _, _ = c.Writer.Write([]byte(newResp)) + cliCancel(resp) + break + } + } +} diff --git a/internal/api/handlers/gemini/gemini_handlers.go b/internal/api/handlers/gemini/gemini_handlers.go index 890f42cf..f22d7077 100644 --- a/internal/api/handlers/gemini/gemini_handlers.go +++ b/internal/api/handlers/gemini/gemini_handlers.go @@ -19,6 +19,7 @@ import ( translatorGeminiToClaude "github.com/luispater/CLIProxyAPI/internal/translator/claude/gemini" translatorGeminiToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/gemini" translatorGeminiToGeminiCli "github.com/luispater/CLIProxyAPI/internal/translator/gemini-cli/gemini/cli" + translatorGeminiToQwen "github.com/luispater/CLIProxyAPI/internal/translator/openai/gemini" "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -241,6 +242,13 @@ func (h *GeminiAPIHandlers) GeminiHandler(c *gin.Context) { case "streamGenerateContent": h.handleClaudeStreamGenerateContent(c, rawJSON) } + } else if provider == "qwen" { + switch method { + case "generateContent": + h.handleQwenGenerateContent(c, rawJSON) + case "streamGenerateContent": + h.handleQwenStreamGenerateContent(c, rawJSON) + } } } @@ -961,3 +969,163 @@ outLoop: } } } + +func (h *GeminiAPIHandlers) handleQwenStreamGenerateContent(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + + // Get the http.Flusher interface to manually flush the response. 
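+	// If the writer cannot flush, streaming is impossible and the handler
+	// fails fast with a server error below.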
+ flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + // Prepare the request for the backend client. + newRequestJSON := translatorGeminiToQwen.ConvertGeminiRequestToOpenAI(rawJSON) + newRequestJSON, _ = sjson.Set(newRequestJSON, "stream", true) + // log.Debugf("Request: %s", newRequestJSON) + + modelName := gjson.GetBytes(rawJSON, "model") + + cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background()) + + var cliClient client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName.String()) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + flusher.Flush() + cliCancel() + return + } + + log.Debugf("Request use qwen account: %s", cliClient.GetEmail()) + + // Send the message and receive response chunks and errors via channels. + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "") + + params := &translatorGeminiToQwen.ConvertOpenAIResponseToGeminiParams{ + ToolCallsAccumulator: nil, + ContentAccumulator: strings.Builder{}, + IsFirstChunk: false, + } + for { + select { + // Handle client disconnection. + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. + case chunk, okStream := <-respChan: + if !okStream { + cliCancel() + return + } + + h.AddAPIResponseData(c, chunk) + h.AddAPIResponseData(c, []byte("\n\n")) + if bytes.HasPrefix(chunk, []byte("data: ")) { + jsonData := chunk[6:] + outputs := translatorGeminiToQwen.ConvertOpenAIResponseToGemini(jsonData, params) + if len(outputs) > 0 { + for i := 0; i < len(outputs); i++ { + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write([]byte(outputs[i])) + _, _ = c.Writer.Write([]byte("\n\n")) + } + } + // log.Debugf(string(jsonData)) + } + flusher.Flush() + // Handle errors from the backend. + case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue outLoop + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel(err.Error) + } + return + } + // Send a keep-alive signal to the client. + case <-time.After(500 * time.Millisecond): + } + } + } +} + +func (h *GeminiAPIHandlers) handleQwenGenerateContent(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "application/json") + + // Prepare the request for the backend client. 
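+	// ConvertGeminiRequestToOpenAI rewrites the Gemini-style body into an
+	// OpenAI-compatible chat completion request accepted by the Qwen endpoint.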
+ newRequestJSON := translatorGeminiToQwen.ConvertGeminiRequestToOpenAI(rawJSON) + // log.Debugf("Request: %s", newRequestJSON) + + modelName := gjson.GetBytes(rawJSON, "model") + + cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background()) + + var cliClient client.Client + defer func() { + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName.String()) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + cliCancel() + return + } + + log.Debugf("Request use qwen account: %s", cliClient.GetEmail()) + + resp, err := cliClient.SendRawMessage(cliCtx, []byte(newRequestJSON), "") + if err != nil { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue + } else { + c.Status(err.StatusCode) + _, _ = c.Writer.Write([]byte(err.Error.Error())) + cliCancel(err.Error) + } + break + } else { + h.AddAPIResponseData(c, resp) + h.AddAPIResponseData(c, []byte("\n")) + + newResp := translatorGeminiToQwen.ConvertOpenAINonStreamResponseToGemini(resp) + _, _ = c.Writer.Write([]byte(newResp)) + cliCancel(resp) + break + } + } +} diff --git a/internal/api/handlers/handlers.go b/internal/api/handlers/handlers.go index fd746b73..c6b7b5d5 100644 --- a/internal/api/handlers/handlers.go +++ b/internal/api/handlers/handlers.go @@ -118,6 +118,12 @@ func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (cl clients = append(clients, cli) } } + } else if provider == "qwen" { + for i := 0; i < len(h.CliClients); i++ { + if cli, ok := h.CliClients[i].(*client.QwenClient); ok { + clients = append(clients, cli) + } + } } if _, hasKey := h.LastUsedClientIndex[provider]; !hasKey { @@ -150,6 +156,8 @@ func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (cl log.Debugf("Codex Model %s is quota exceeded for account %s", modelName, cliClient.GetEmail()) } else if provider == "claude" { log.Debugf("Claude Model %s is quota exceeded for account %s", modelName, cliClient.GetEmail()) + } else if provider == "qwen" { + log.Debugf("Qwen Model %s is quota exceeded for account %s", modelName, cliClient.GetEmail()) } cliClient = nil continue diff --git a/internal/api/handlers/openai/openai_handlers.go b/internal/api/handlers/openai/openai_handlers.go index efd3810a..ae8eb965 100644 --- a/internal/api/handlers/openai/openai_handlers.go +++ b/internal/api/handlers/openai/openai_handlers.go @@ -171,6 +171,13 @@ func (h *OpenAIAPIHandlers) ChatCompletions(c *gin.Context) { } else { h.handleClaudeNonStreamingResponse(c, rawJSON) } + } else if provider == "qwen" { + // qwen3-coder-plus / qwen3-coder-flash + if streamResult.Type == gjson.True { + h.handleQwenStreamingResponse(c, rawJSON) + } else { + h.handleQwenNonStreamingResponse(c, rawJSON) + } } } @@ -761,3 +768,155 @@ outLoop: } } } + +// handleQwenNonStreamingResponse handles non-streaming chat completion responses +// for Qwen models. It selects a client from the pool, sends the request, and +// aggregates the response before sending it back to the client in OpenAI format. 
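+// Because Qwen exposes an OpenAI-compatible endpoint, the request body is
+// forwarded to the backend unchanged and the response is written back verbatim.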
+// +// Parameters: +// - c: The Gin context containing the HTTP request and response +// - rawJSON: The raw JSON bytes of the OpenAI-compatible request +func (h *OpenAIAPIHandlers) handleQwenNonStreamingResponse(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "application/json") + + modelResult := gjson.GetBytes(rawJSON, "model") + modelName := modelResult.String() + cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background()) + + var cliClient client.Client + defer func() { + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + cliCancel() + return + } + + log.Debugf("Request qwen use account: %s", cliClient.(*client.QwenClient).GetEmail()) + + resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, modelName) + if err != nil { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue + } else { + c.Status(err.StatusCode) + _, _ = c.Writer.Write([]byte(err.Error.Error())) + cliCancel(err.Error) + } + break + } else { + _, _ = c.Writer.Write(resp) + cliCancel(resp) + break + } + } +} + +// handleQwenStreamingResponse handles streaming responses for Qwen models. +// It establishes a streaming connection with the backend service and forwards +// the response chunks to the client in real-time using Server-Sent Events. +// +// Parameters: +// - c: The Gin context containing the HTTP request and response +// - rawJSON: The raw JSON bytes of the OpenAI-compatible request +func (h *OpenAIAPIHandlers) handleQwenStreamingResponse(c *gin.Context, rawJSON []byte) { + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Access-Control-Allow-Origin", "*") + + // Get the http.Flusher interface to manually flush the response. + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + // Prepare the request for the backend client. + modelResult := gjson.GetBytes(rawJSON, "model") + modelName := modelResult.String() + + cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background()) + + var cliClient client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. + if cliClient != nil { + cliClient.GetRequestMutex().Unlock() + } + }() + +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.GetClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + flusher.Flush() + cliCancel() + return + } + + log.Debugf("Request qwen use account: %s", cliClient.(*client.QwenClient).GetEmail()) + + // Send the message and receive response chunks and errors via channels. + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, modelName) + + for { + select { + // Handle client disconnection. + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. 
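+			// Qwen already returns OpenAI-compatible SSE events, so each chunk is
+			// forwarded to the client without translation.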
+ case chunk, okStream := <-respChan: + if !okStream { + flusher.Flush() + cliCancel() + return + } + + h.AddAPIResponseData(c, chunk) + h.AddAPIResponseData(c, []byte("\n")) + + // Convert the chunk to OpenAI format and send it to the client. + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n")) + + flusher.Flush() + // Handle errors from the backend. + case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { + continue outLoop + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel(err.Error) + } + return + } + // Send a keep-alive signal to the client. + case <-time.After(500 * time.Millisecond): + } + } + } +} diff --git a/internal/api/server.go b/internal/api/server.go index a4b3fd5a..9d2791e1 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -66,7 +66,7 @@ func NewServer(cfg *config.Config, cliClients []client.Client) *Server { requestLogger := logging.NewFileRequestLogger(cfg.RequestLog, "logs") engine.Use(middleware.RequestLoggingMiddleware(requestLogger)) - // engine.Use(corsMiddleware()) + engine.Use(corsMiddleware()) // Create server instance s := &Server{ diff --git a/internal/auth/qwen/qwen_auth.go b/internal/auth/qwen/qwen_auth.go new file mode 100644 index 00000000..e3989f63 --- /dev/null +++ b/internal/auth/qwen/qwen_auth.go @@ -0,0 +1,337 @@ +package qwen + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/luispater/CLIProxyAPI/internal/config" + "github.com/luispater/CLIProxyAPI/internal/util" + log "github.com/sirupsen/logrus" +) + +const ( + // OAuth Configuration + QwenOAuthDeviceCodeEndpoint = "https://chat.qwen.ai/api/v1/oauth2/device/code" + QwenOAuthTokenEndpoint = "https://chat.qwen.ai/api/v1/oauth2/token" + QwenOAuthClientID = "f0304373b74a44d2b584a3fb70ca9e56" + QwenOAuthScope = "openid profile email model.completion" + QwenOAuthGrantType = "urn:ietf:params:oauth:grant-type:device_code" +) + +// QwenTokenData represents OAuth credentials +type QwenTokenData struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type"` + ResourceURL string `json:"resource_url,omitempty"` + Expire string `json:"expiry_date,omitempty"` +} + +// DeviceFlow represents device flow response +type DeviceFlow struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` + CodeVerifier string `json:"code_verifier"` +} + +// QwenTokenResponse represents token response +type QwenTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type"` + ResourceURL string `json:"resource_url,omitempty"` + ExpiresIn int `json:"expires_in"` +} + +// QwenAuth manages authentication and credentials +type QwenAuth struct { + httpClient *http.Client +} + +// NewQwenAuth creates a new QwenAuth +func NewQwenAuth(cfg *config.Config) *QwenAuth { + return &QwenAuth{ + httpClient: util.SetProxy(cfg, &http.Client{}), + } +} + +// generateCodeVerifier generates a random code verifier for PKCE +func (qa *QwenAuth) generateCodeVerifier() (string, error) { + bytes 
:= make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(bytes), nil +} + +// generateCodeChallenge generates a code challenge from a code verifier using SHA-256 +func (qa *QwenAuth) generateCodeChallenge(codeVerifier string) string { + hash := sha256.Sum256([]byte(codeVerifier)) + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + +// generatePKCEPair generates PKCE code verifier and challenge pair +func (qa *QwenAuth) generatePKCEPair() (string, string, error) { + codeVerifier, err := qa.generateCodeVerifier() + if err != nil { + return "", "", err + } + codeChallenge := qa.generateCodeChallenge(codeVerifier) + return codeVerifier, codeChallenge, nil +} + +// RefreshTokens refreshes the access token using refresh token +func (qa *QwenAuth) RefreshTokens(ctx context.Context, refreshToken string) (*QwenTokenData, error) { + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("refresh_token", refreshToken) + data.Set("client_id", QwenOAuthClientID) + + req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthTokenEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := qa.httpClient.Do(req) + + // resp, err := qa.httpClient.PostForm(QwenOAuthTokenEndpoint, data) + if err != nil { + return nil, fmt.Errorf("token refresh request failed: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var errorData map[string]interface{} + if err = json.Unmarshal(body, &errorData); err == nil { + return nil, fmt.Errorf("token refresh failed: %v - %v", errorData["error"], errorData["error_description"]) + } + return nil, fmt.Errorf("token refresh failed: %s", string(body)) + } + + var tokenData QwenTokenResponse + if err = json.Unmarshal(body, &tokenData); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + return &QwenTokenData{ + AccessToken: tokenData.AccessToken, + TokenType: tokenData.TokenType, + RefreshToken: tokenData.RefreshToken, + ResourceURL: tokenData.ResourceURL, + Expire: time.Now().Add(time.Duration(tokenData.ExpiresIn) * time.Second).Format(time.RFC3339), + }, nil +} + +// InitiateDeviceFlow initiates the OAuth device flow +func (qa *QwenAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceFlow, error) { + // Generate PKCE code verifier and challenge + codeVerifier, codeChallenge, err := qa.generatePKCEPair() + if err != nil { + return nil, fmt.Errorf("failed to generate PKCE pair: %w", err) + } + + data := url.Values{} + data.Set("client_id", QwenOAuthClientID) + data.Set("scope", QwenOAuthScope) + data.Set("code_challenge", codeChallenge) + data.Set("code_challenge_method", "S256") + + req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthDeviceCodeEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("failed to create token request: %w", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Accept", "application/json") + + resp, err := qa.httpClient.Do(req) + + // resp, err := qa.httpClient.PostForm(QwenOAuthDeviceCodeEndpoint, data) + if err != nil { + return nil, 
fmt.Errorf("device authorization request failed: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("device authorization failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body)) + } + + var result DeviceFlow + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse device flow response: %w", err) + } + + // Check if the response indicates success + if result.DeviceCode == "" { + return nil, fmt.Errorf("device authorization failed: device_code not found in response") + } + + // Add the code_verifier to the result so it can be used later for polling + result.CodeVerifier = codeVerifier + + return &result, nil +} + +// PollForToken polls for the access token using device code +func (qa *QwenAuth) PollForToken(deviceCode, codeVerifier string) (*QwenTokenData, error) { + pollInterval := 5 * time.Second + maxAttempts := 60 // 5 minutes max + + for attempt := 0; attempt < maxAttempts; attempt++ { + data := url.Values{} + data.Set("grant_type", QwenOAuthGrantType) + data.Set("client_id", QwenOAuthClientID) + data.Set("device_code", deviceCode) + data.Set("code_verifier", codeVerifier) + + resp, err := http.PostForm(QwenOAuthTokenEndpoint, data) + if err != nil { + fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err) + time.Sleep(pollInterval) + continue + } + + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err) + time.Sleep(pollInterval) + continue + } + + if resp.StatusCode != http.StatusOK { + // Parse the response as JSON to check for OAuth RFC 8628 standard errors + var errorData map[string]interface{} + if err = json.Unmarshal(body, &errorData); err == nil { + // According to OAuth RFC 8628, handle standard polling responses + if resp.StatusCode == http.StatusBadRequest { + errorType, _ := errorData["error"].(string) + switch errorType { + case "authorization_pending": + // User has not yet approved the authorization request. Continue polling. + log.Infof("Polling attempt %d/%d...\n", attempt+1, maxAttempts) + time.Sleep(pollInterval) + continue + case "slow_down": + // Client is polling too frequently. Increase poll interval. + pollInterval = time.Duration(float64(pollInterval) * 1.5) + if pollInterval > 10*time.Second { + pollInterval = 10 * time.Second + } + log.Infof("Server requested to slow down, increasing poll interval to %v\n", pollInterval) + time.Sleep(pollInterval) + continue + case "expired_token": + return nil, fmt.Errorf("device code expired. Please restart the authentication process") + case "access_denied": + return nil, fmt.Errorf("authorization denied by user. Please restart the authentication process") + } + } + + // For other errors, return with proper error information + errorType, _ := errorData["error"].(string) + errorDesc, _ := errorData["error_description"].(string) + return nil, fmt.Errorf("device token poll failed: %s - %s", errorType, errorDesc) + } + + // If JSON parsing fails, fall back to text response + return nil, fmt.Errorf("device token poll failed: %d %s. 
Response: %s", resp.StatusCode, resp.Status, string(body)) + } + log.Debugf(string(body)) + // Success - parse token data + var response QwenTokenResponse + if err = json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("failed to parse token response: %w", err) + } + + // Convert to QwenTokenData format and save + tokenData := &QwenTokenData{ + AccessToken: response.AccessToken, + RefreshToken: response.RefreshToken, + TokenType: response.TokenType, + ResourceURL: response.ResourceURL, + Expire: time.Now().Add(time.Duration(response.ExpiresIn) * time.Second).Format(time.RFC3339), + } + + return tokenData, nil + } + + return nil, fmt.Errorf("authentication timeout. Please restart the authentication process") +} + +// RefreshTokensWithRetry refreshes tokens with automatic retry logic +func (o *QwenAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*QwenTokenData, error) { + var lastErr error + + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + // Wait before retry + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(attempt) * time.Second): + } + } + + tokenData, err := o.RefreshTokens(ctx, refreshToken) + if err == nil { + return tokenData, nil + } + + lastErr = err + log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err) + } + + return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr) +} + +func (o *QwenAuth) CreateTokenStorage(tokenData *QwenTokenData) *QwenTokenStorage { + storage := &QwenTokenStorage{ + AccessToken: tokenData.AccessToken, + RefreshToken: tokenData.RefreshToken, + LastRefresh: time.Now().Format(time.RFC3339), + ResourceURL: tokenData.ResourceURL, + Expire: tokenData.Expire, + } + + return storage +} + +// UpdateTokenStorage updates an existing token storage with new token data +func (o *QwenAuth) UpdateTokenStorage(storage *QwenTokenStorage, tokenData *QwenTokenData) { + storage.AccessToken = tokenData.AccessToken + storage.RefreshToken = tokenData.RefreshToken + storage.LastRefresh = time.Now().Format(time.RFC3339) + storage.ResourceURL = tokenData.ResourceURL + storage.Expire = tokenData.Expire +} diff --git a/internal/auth/qwen/qwen_token.go b/internal/auth/qwen/qwen_token.go new file mode 100644 index 00000000..733911cb --- /dev/null +++ b/internal/auth/qwen/qwen_token.go @@ -0,0 +1,61 @@ +// Package gemini provides authentication and token management functionality +// for Google's Gemini AI services. It handles OAuth2 token storage, serialization, +// and retrieval for maintaining authenticated sessions with the Gemini API. +package qwen + +import ( + "encoding/json" + "fmt" + "os" + "path" +) + +// QwenTokenStorage defines the structure for storing OAuth2 token information, +// along with associated user and project details. This data is typically +// serialized to a JSON file for persistence. +type QwenTokenStorage struct { + // AccessToken is the OAuth2 access token for API access + AccessToken string `json:"access_token"` + // RefreshToken is used to obtain new access tokens + RefreshToken string `json:"refresh_token"` + // LastRefresh is the timestamp of the last token refresh + LastRefresh string `json:"last_refresh"` + // ResourceURL is the request base url + ResourceURL string `json:"resource_url"` + // Email is the OpenAI account email + Email string `json:"email"` + // Type indicates the type (gemini, chatgpt, claude) of token storage. 
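+	// For Qwen credentials SaveTokenToFile always stores "qwen" here.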
+ Type string `json:"type"` + // Expire is the timestamp of the token expire + Expire string `json:"expired"` +} + +// SaveTokenToFile serializes the token storage to a JSON file. +// This method creates the necessary directory structure and writes the token +// data in JSON format to the specified file path. It ensures the file is +// properly closed after writing. +// +// Parameters: +// - authFilePath: The full path where the token file should be saved +// +// Returns: +// - error: An error if the operation fails, nil otherwise +func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error { + ts.Type = "qwen" + if err := os.MkdirAll(path.Dir(authFilePath), 0700); err != nil { + return fmt.Errorf("failed to create directory: %v", err) + } + + f, err := os.Create(authFilePath) + if err != nil { + return fmt.Errorf("failed to create token file: %w", err) + } + defer func() { + _ = f.Close() + }() + + if err = json.NewEncoder(f).Encode(ts); err != nil { + return fmt.Errorf("failed to write token to file: %w", err) + } + return nil +} diff --git a/internal/client/qwen_client.go b/internal/client/qwen_client.go new file mode 100644 index 00000000..491ff117 --- /dev/null +++ b/internal/client/qwen_client.go @@ -0,0 +1,288 @@ +package client + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/luispater/CLIProxyAPI/internal/auth" + "github.com/luispater/CLIProxyAPI/internal/auth/qwen" + "github.com/luispater/CLIProxyAPI/internal/config" + "github.com/luispater/CLIProxyAPI/internal/util" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + qwenEndpoint = "https://portal.qwen.ai/v1" +) + +// QwenClient implements the Client interface for OpenAI API +type QwenClient struct { + ClientBase + qwenAuth *qwen.QwenAuth +} + +// NewQwenClient creates a new OpenAI client instance +func NewQwenClient(cfg *config.Config, ts *qwen.QwenTokenStorage) *QwenClient { + httpClient := util.SetProxy(cfg, &http.Client{}) + client := &QwenClient{ + ClientBase: ClientBase{ + RequestMutex: &sync.Mutex{}, + httpClient: httpClient, + cfg: cfg, + modelQuotaExceeded: make(map[string]*time.Time), + tokenStorage: ts, + }, + qwenAuth: qwen.NewQwenAuth(cfg), + } + + return client +} + +// GetUserAgent returns the user agent string for OpenAI API requests +func (c *QwenClient) GetUserAgent() string { + return "google-api-nodejs-client/9.15.1" +} + +func (c *QwenClient) TokenStorage() auth.TokenStorage { + return c.tokenStorage +} + +// SendMessage sends a message to OpenAI API (non-streaming) +func (c *QwenClient) SendMessage(_ context.Context, _ []byte, _ string, _ *Content, _ []Content, _ []ToolDeclaration) ([]byte, *ErrorMessage) { + // For now, return an error as OpenAI integration is not fully implemented + return nil, &ErrorMessage{ + StatusCode: http.StatusNotImplemented, + Error: fmt.Errorf("qwen message sending not yet implemented"), + } +} + +// SendMessageStream sends a streaming message to OpenAI API +func (c *QwenClient) SendMessageStream(_ context.Context, _ []byte, _ string, _ *Content, _ []Content, _ []ToolDeclaration, _ ...bool) (<-chan []byte, <-chan *ErrorMessage) { + errChan := make(chan *ErrorMessage, 1) + errChan <- &ErrorMessage{ + StatusCode: http.StatusNotImplemented, + Error: fmt.Errorf("qwen streaming not yet implemented"), + } + close(errChan) + + return nil, errChan +} + +// SendRawMessage sends a raw 
message to OpenAI API +func (c *QwenClient) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) { + modelResult := gjson.GetBytes(rawJSON, "model") + model := modelResult.String() + modelName := model + + respBody, err := c.APIRequest(ctx, "/chat/completions", rawJSON, alt, false) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + } + return nil, err + } + delete(c.modelQuotaExceeded, modelName) + bodyBytes, errReadAll := io.ReadAll(respBody) + if errReadAll != nil { + return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} + } + return bodyBytes, nil + +} + +// SendRawMessageStream sends a raw streaming message to OpenAI API +func (c *QwenClient) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) { + errChan := make(chan *ErrorMessage) + dataChan := make(chan []byte) + go func() { + defer close(errChan) + defer close(dataChan) + + modelResult := gjson.GetBytes(rawJSON, "model") + model := modelResult.String() + modelName := model + var stream io.ReadCloser + for { + var err *ErrorMessage + stream, err = c.APIRequest(ctx, "/chat/completions", rawJSON, alt, true) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + } + errChan <- err + return + } + delete(c.modelQuotaExceeded, modelName) + break + } + + scanner := bufio.NewScanner(stream) + buffer := make([]byte, 10240*1024) + scanner.Buffer(buffer, 10240*1024) + for scanner.Scan() { + line := scanner.Bytes() + dataChan <- line + } + + if errScanner := scanner.Err(); errScanner != nil { + errChan <- &ErrorMessage{500, errScanner, nil} + _ = stream.Close() + return + } + + _ = stream.Close() + }() + + return dataChan, errChan +} + +// SendRawTokenCount sends a token count request to OpenAI API +func (c *QwenClient) SendRawTokenCount(_ context.Context, _ []byte, _ string) ([]byte, *ErrorMessage) { + return nil, &ErrorMessage{ + StatusCode: http.StatusNotImplemented, + Error: fmt.Errorf("qwen token counting not yet implemented"), + } +} + +// SaveTokenToFile persists the token storage to disk +func (c *QwenClient) SaveTokenToFile() error { + fileName := filepath.Join(c.cfg.AuthDir, fmt.Sprintf("qwen-%s.json", c.tokenStorage.(*qwen.QwenTokenStorage).Email)) + return c.tokenStorage.SaveTokenToFile(fileName) +} + +// RefreshTokens refreshes the access tokens if needed +func (c *QwenClient) RefreshTokens(ctx context.Context) error { + if c.tokenStorage == nil || c.tokenStorage.(*qwen.QwenTokenStorage).RefreshToken == "" { + return fmt.Errorf("no refresh token available") + } + + // Refresh tokens using the auth service + newTokenData, err := c.qwenAuth.RefreshTokensWithRetry(ctx, c.tokenStorage.(*qwen.QwenTokenStorage).RefreshToken, 3) + if err != nil { + return fmt.Errorf("failed to refresh tokens: %w", err) + } + + // Update token storage + c.qwenAuth.UpdateTokenStorage(c.tokenStorage.(*qwen.QwenTokenStorage), newTokenData) + + // Save updated tokens + if err = c.SaveTokenToFile(); err != nil { + log.Warnf("Failed to save refreshed tokens: %v", err) + } + + log.Debug("qwen tokens refreshed successfully") + return nil +} + +// APIRequest handles making requests to the CLI API endpoints. 
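+// The JSON payload is POSTed to Qwen's OpenAI-compatible API with the stored
+// bearer token, and the raw response body is returned as a stream for the
+// caller to consume.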
+func (c *QwenClient) APIRequest(ctx context.Context, endpoint string, body interface{}, _ string, _ bool) (io.ReadCloser, *ErrorMessage) { + var jsonBody []byte + var err error + if byteBody, ok := body.([]byte); ok { + jsonBody = byteBody + } else { + jsonBody, err = json.Marshal(body) + if err != nil { + return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err), nil} + } + } + + streamResult := gjson.GetBytes(jsonBody, "stream") + if streamResult.Exists() && streamResult.Type == gjson.True { + jsonBody, _ = sjson.SetBytes(jsonBody, "stream_options.include_usage", true) + } + + var url string + if c.tokenStorage.(*qwen.QwenTokenStorage).ResourceURL == "" { + url = fmt.Sprintf("https://%s/v1%s", c.tokenStorage.(*qwen.QwenTokenStorage).ResourceURL, endpoint) + } else { + url = fmt.Sprintf("%s%s", qwenEndpoint, endpoint) + } + + // log.Debug(string(jsonBody)) + // log.Debug(url) + reqBody := bytes.NewBuffer(jsonBody) + + req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody) + if err != nil { + return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %v", err), nil} + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", c.GetUserAgent()) + req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0") + req.Header.Set("Client-Metadata", c.getClientMetadataString()) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.tokenStorage.(*qwen.QwenTokenStorage).AccessToken)) + + if ginContext, ok := ctx.Value("gin").(*gin.Context); ok { + ginContext.Set("API_REQUEST", jsonBody) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %v", err), nil} + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer func() { + if err = resp.Body.Close(); err != nil { + log.Printf("warn: failed to close response body: %v", err) + } + }() + bodyBytes, _ := io.ReadAll(resp.Body) + // log.Debug(string(jsonBody)) + return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes)), nil} + } + + return resp.Body, nil +} + +func (c *QwenClient) getClientMetadata() map[string]string { + return map[string]string{ + "ideType": "IDE_UNSPECIFIED", + "platform": "PLATFORM_UNSPECIFIED", + "pluginType": "GEMINI", + // "pluginVersion": pluginVersion, + } +} + +func (c *QwenClient) getClientMetadataString() string { + md := c.getClientMetadata() + parts := make([]string, 0, len(md)) + for k, v := range md { + parts = append(parts, fmt.Sprintf("%s=%s", k, v)) + } + return strings.Join(parts, ",") +} + +func (c *QwenClient) GetEmail() string { + return c.tokenStorage.(*qwen.QwenTokenStorage).Email +} + +// IsModelQuotaExceeded returns true if the specified model has exceeded its quota +// and no fallback options are available. 
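+// The quota flag recorded on a 429 response expires after 30 minutes, after
+// which the model is treated as available again.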
+func (c *QwenClient) IsModelQuotaExceeded(model string) bool { + if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey { + duration := time.Now().Sub(*lastExceededTime) + if duration > 30*time.Minute { + return false + } + return true + } + return false +} diff --git a/internal/cmd/qwen_login.go b/internal/cmd/qwen_login.go new file mode 100644 index 00000000..953d29a0 --- /dev/null +++ b/internal/cmd/qwen_login.go @@ -0,0 +1,85 @@ +package cmd + +import ( + "context" + "fmt" + "os" + + "github.com/luispater/CLIProxyAPI/internal/auth/qwen" + "github.com/luispater/CLIProxyAPI/internal/browser" + "github.com/luispater/CLIProxyAPI/internal/client" + "github.com/luispater/CLIProxyAPI/internal/config" + log "github.com/sirupsen/logrus" +) + +// DoQwenLogin handles the Qwen OAuth login process +func DoQwenLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + ctx := context.Background() + + log.Info("Initializing Qwen authentication...") + + // Initialize Qwen auth service + qwenAuth := qwen.NewQwenAuth(cfg) + + // Generate authorization URL + deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx) + if err != nil { + log.Fatalf("Failed to generate authorization URL: %v", err) + return + } + authURL := deviceFlow.VerificationURIComplete + + // Open browser or display URL + if !options.NoBrowser { + log.Info("Opening browser for authentication...") + + // Check if browser is available + if !browser.IsAvailable() { + log.Warn("No browser available on this system") + log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL) + } else { + if err = browser.OpenURL(authURL); err != nil { + log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL) + + // Log platform info for debugging + platformInfo := browser.GetPlatformInfo() + log.Debugf("Browser platform info: %+v", platformInfo) + } else { + log.Debug("Browser opened successfully") + } + } + } else { + log.Infof("Please open this URL in your browser:\n\n%s\n", authURL) + } + + log.Info("Waiting for authentication...") + tokenData, err := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) + if err != nil { + fmt.Printf("Authentication failed: %v\n", err) + os.Exit(1) + } + + // Create token storage + tokenStorage := qwenAuth.CreateTokenStorage(tokenData) + + // Initialize Qwen client + qwenClient := client.NewQwenClient(cfg, tokenStorage) + + fmt.Println("\nPlease input your email address or any alias:") + var email string + _, _ = fmt.Scanln(&email) + tokenStorage.Email = email + + // Save token storage + if err = qwenClient.SaveTokenToFile(); err != nil { + log.Fatalf("Failed to save authentication tokens: %v", err) + return + } + + log.Info("Authentication successful!") + log.Info("You can now use Qwen services through this CLI") +} diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 87a93246..63823d44 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -22,6 +22,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/auth/claude" "github.com/luispater/CLIProxyAPI/internal/auth/codex" "github.com/luispater/CLIProxyAPI/internal/auth/gemini" + "github.com/luispater/CLIProxyAPI/internal/auth/qwen" "github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/config" "github.com/luispater/CLIProxyAPI/internal/util" @@ -102,6 +103,15 @@ func StartService(cfg *config.Config, configPath string) { log.Info("Authentication successful.") cliClients = append(cliClients, claudeClient) } 
+ } else if tokenType == "qwen" { + var ts qwen.QwenTokenStorage + if err = json.Unmarshal(data, &ts); err == nil { + // For each valid token, create an authenticated client. + log.Info("Initializing qwen authentication for token...") + qwenClient := client.NewQwenClient(cfg, &ts) + log.Info("Authentication successful.") + cliClients = append(cliClients, qwenClient) + } } } return nil @@ -200,12 +210,23 @@ func StartService(cfg *config.Config, configPath string) { if ts != nil && ts.Expire != "" { if expTime, errParse := time.Parse(time.RFC3339, ts.Expire); errParse == nil { if time.Until(expTime) <= 4*time.Hour { - log.Debugf("refreshing codex tokens for %s", claudeCli.GetEmail()) + log.Debugf("refreshing claude tokens for %s", claudeCli.GetEmail()) _ = claudeCli.RefreshTokens(ctxRefresh) } } } } + } else if qwenCli, isQwenOK := cliClients[i].(*client.QwenClient); isQwenOK { + if ts, isQwenTS := qwenCli.TokenStorage().(*qwen.QwenTokenStorage); isQwenTS { + if ts != nil && ts.Expire != "" { + if expTime, errParse := time.Parse(time.RFC3339, ts.Expire); errParse == nil { + if time.Until(expTime) <= 3*time.Hour { + log.Debugf("refreshing qwen tokens for %s", qwenCli.GetEmail()) + _ = qwenCli.RefreshTokens(ctxRefresh) + } + } + } + } } } } diff --git a/internal/translator/openai/claude/openai_claude_request.go b/internal/translator/openai/claude/openai_claude_request.go new file mode 100644 index 00000000..9937725f --- /dev/null +++ b/internal/translator/openai/claude/openai_claude_request.go @@ -0,0 +1,253 @@ +// Package claude provides request translation functionality for Anthropic to OpenAI API. +// It handles parsing and transforming Anthropic API requests into OpenAI Chat Completions API format, +// extracting model information, system instructions, message contents, and tool declarations. +// The package performs JSON data transformation to ensure compatibility +// between Anthropic API format and OpenAI API's expected format. +package claude + +import ( + "encoding/json" + "strings" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ConvertAnthropicRequestToOpenAI parses and transforms an Anthropic API request into OpenAI Chat Completions API format. +// It extracts the model name, system instruction, message contents, and tool declarations +// from the raw JSON request and returns them in the format expected by the OpenAI API. 
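+//
+// Illustrative mapping (abbreviated, not exhaustive):
+//
+//	{"model":"m","system":"S","max_tokens":64,
+//	 "messages":[{"role":"user","content":[{"type":"text","text":"Hi"}]}]}
+//	=> {"model":"m","max_tokens":64,
+//	    "messages":[{"role":"system","content":"S"},{"role":"user","content":"Hi"}]}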
+func ConvertAnthropicRequestToOpenAI(rawJSON []byte) string { + // Base OpenAI Chat Completions API template + out := `{"model":"","messages":[]}` + + root := gjson.ParseBytes(rawJSON) + + // Model mapping + if model := root.Get("model"); model.Exists() { + modelStr := model.String() + out, _ = sjson.Set(out, "model", modelStr) + } + + // Max tokens + if maxTokens := root.Get("max_tokens"); maxTokens.Exists() { + out, _ = sjson.Set(out, "max_tokens", maxTokens.Int()) + } + + // Temperature + if temp := root.Get("temperature"); temp.Exists() { + out, _ = sjson.Set(out, "temperature", temp.Float()) + } + + // Top P + if topP := root.Get("top_p"); topP.Exists() { + out, _ = sjson.Set(out, "top_p", topP.Float()) + } + + // Stop sequences -> stop + if stopSequences := root.Get("stop_sequences"); stopSequences.Exists() { + if stopSequences.IsArray() { + var stops []string + stopSequences.ForEach(func(_, value gjson.Result) bool { + stops = append(stops, value.String()) + return true + }) + if len(stops) > 0 { + if len(stops) == 1 { + out, _ = sjson.Set(out, "stop", stops[0]) + } else { + out, _ = sjson.Set(out, "stop", stops) + } + } + } + } + + // Stream + if stream := root.Get("stream"); stream.Exists() { + out, _ = sjson.Set(out, "stream", stream.Bool()) + } + + // Process messages and system + var openAIMessages []interface{} + + // Handle system message first + if system := root.Get("system"); system.Exists() && system.String() != "" { + systemMsg := map[string]interface{}{ + "role": "system", + "content": system.String(), + } + openAIMessages = append(openAIMessages, systemMsg) + } + + // Process Anthropic messages + if messages := root.Get("messages"); messages.Exists() && messages.IsArray() { + messages.ForEach(func(_, message gjson.Result) bool { + role := message.Get("role").String() + contentResult := message.Get("content") + + msg := map[string]interface{}{ + "role": role, + } + + // Handle content + if contentResult.Exists() && contentResult.IsArray() { + var textParts []string + var toolCalls []interface{} + var toolResults []interface{} + + contentResult.ForEach(func(_, part gjson.Result) bool { + partType := part.Get("type").String() + + switch partType { + case "text": + textParts = append(textParts, part.Get("text").String()) + + case "image": + // Convert Anthropic image format to OpenAI format + if source := part.Get("source"); source.Exists() { + sourceType := source.Get("type").String() + if sourceType == "base64" { + mediaType := source.Get("media_type").String() + data := source.Get("data").String() + imageURL := "data:" + mediaType + ";base64," + data + + // For now, add as text since OpenAI image handling is complex + // In a real implementation, you'd need to handle this properly + textParts = append(textParts, "[Image: "+imageURL+"]") + } + } + + case "tool_use": + // Convert to OpenAI tool call format + toolCall := map[string]interface{}{ + "id": part.Get("id").String(), + "type": "function", + "function": map[string]interface{}{ + "name": part.Get("name").String(), + }, + } + + // Convert input to arguments JSON string + if input := part.Get("input"); input.Exists() { + if inputJSON, err := json.Marshal(input.Value()); err == nil { + if function, ok := toolCall["function"].(map[string]interface{}); ok { + function["arguments"] = string(inputJSON) + } + } else { + if function, ok := toolCall["function"].(map[string]interface{}); ok { + function["arguments"] = "{}" + } + } + } else { + if function, ok := toolCall["function"].(map[string]interface{}); ok { + 
function["arguments"] = "{}" + } + } + + toolCalls = append(toolCalls, toolCall) + + case "tool_result": + // Convert to OpenAI tool message format + toolResult := map[string]interface{}{ + "role": "tool", + "tool_call_id": part.Get("tool_use_id").String(), + "content": part.Get("content").String(), + } + toolResults = append(toolResults, toolResult) + } + return true + }) + + // Set content + if len(textParts) > 0 { + msg["content"] = strings.Join(textParts, "") + } else { + msg["content"] = "" + } + + // Set tool calls for assistant messages + if role == "assistant" && len(toolCalls) > 0 { + msg["tool_calls"] = toolCalls + } + + openAIMessages = append(openAIMessages, msg) + + // Add tool result messages separately + for _, toolResult := range toolResults { + openAIMessages = append(openAIMessages, toolResult) + } + + } else if contentResult.Exists() && contentResult.Type == gjson.String { + // Simple string content + msg["content"] = contentResult.String() + openAIMessages = append(openAIMessages, msg) + } + + return true + }) + } + + // Set messages + if len(openAIMessages) > 0 { + messagesJSON, _ := json.Marshal(openAIMessages) + out, _ = sjson.SetRaw(out, "messages", string(messagesJSON)) + } + + // Process tools - convert Anthropic tools to OpenAI functions + if tools := root.Get("tools"); tools.Exists() && tools.IsArray() { + var openAITools []interface{} + + tools.ForEach(func(_, tool gjson.Result) bool { + openAITool := map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": tool.Get("name").String(), + "description": tool.Get("description").String(), + }, + } + + // Convert Anthropic input_schema to OpenAI function parameters + if inputSchema := tool.Get("input_schema"); inputSchema.Exists() { + if function, ok := openAITool["function"].(map[string]interface{}); ok { + function["parameters"] = inputSchema.Value() + } + } + + openAITools = append(openAITools, openAITool) + return true + }) + + if len(openAITools) > 0 { + toolsJSON, _ := json.Marshal(openAITools) + out, _ = sjson.SetRaw(out, "tools", string(toolsJSON)) + } + } + + // Tool choice mapping - convert Anthropic tool_choice to OpenAI format + if toolChoice := root.Get("tool_choice"); toolChoice.Exists() { + switch toolChoice.Get("type").String() { + case "auto": + out, _ = sjson.Set(out, "tool_choice", "auto") + case "any": + out, _ = sjson.Set(out, "tool_choice", "required") + case "tool": + // Specific tool choice + toolName := toolChoice.Get("name").String() + out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": toolName, + }, + }) + default: + // Default to auto if not specified + out, _ = sjson.Set(out, "tool_choice", "auto") + } + } + + // Handle user parameter (for tracking) + if user := root.Get("user"); user.Exists() { + out, _ = sjson.Set(out, "user", user.String()) + } + + return out +} diff --git a/internal/translator/openai/claude/openai_claude_response.go b/internal/translator/openai/claude/openai_claude_response.go new file mode 100644 index 00000000..a636e484 --- /dev/null +++ b/internal/translator/openai/claude/openai_claude_response.go @@ -0,0 +1,389 @@ +// Package claude provides response translation functionality for OpenAI to Anthropic API. +// This package handles the conversion of OpenAI Chat Completions API responses into Anthropic API-compatible +// JSON format, transforming streaming events and non-streaming responses into the format +// expected by Anthropic API clients. 
It supports both streaming and non-streaming modes, +// handling text content, tool calls, and usage metadata appropriately. +package claude + +import ( + "encoding/json" + "strings" + + "github.com/tidwall/gjson" +) + +// ConvertOpenAIResponseToAnthropicParams holds parameters for response conversion +type ConvertOpenAIResponseToAnthropicParams struct { + MessageID string + Model string + CreatedAt int64 + // Content accumulator for streaming + ContentAccumulator strings.Builder + // Tool calls accumulator for streaming + ToolCallsAccumulator map[int]*ToolCallAccumulator + // Track if text content block has been started + TextContentBlockStarted bool + // Track finish reason for later use + FinishReason string + // Track if content blocks have been stopped + ContentBlocksStopped bool + // Track if message_delta has been sent + MessageDeltaSent bool +} + +// ToolCallAccumulator holds the state for accumulating tool call data +type ToolCallAccumulator struct { + ID string + Name string + Arguments strings.Builder +} + +// ConvertOpenAIResponseToAnthropic converts OpenAI streaming response format to Anthropic API format. +// This function processes OpenAI streaming chunks and transforms them into Anthropic-compatible JSON responses. +// It handles text content, tool calls, and usage metadata, outputting responses that match the Anthropic API format. +func ConvertOpenAIResponseToAnthropic(rawJSON []byte, param *ConvertOpenAIResponseToAnthropicParams) []string { + // Check if this is the [DONE] marker + rawStr := strings.TrimSpace(string(rawJSON)) + if rawStr == "[DONE]" { + return convertOpenAIDoneToAnthropic(param) + } + + root := gjson.ParseBytes(rawJSON) + + // Check if this is a streaming chunk or non-streaming response + objectType := root.Get("object").String() + + if objectType == "chat.completion.chunk" { + // Handle streaming response + return convertOpenAIStreamingChunkToAnthropic(rawJSON, param) + } else if objectType == "chat.completion" { + // Handle non-streaming response + return convertOpenAINonStreamingToAnthropic(rawJSON) + } + + return []string{} +} + +// convertOpenAIStreamingChunkToAnthropic converts OpenAI streaming chunk to Anthropic streaming events +func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAIResponseToAnthropicParams) []string { + root := gjson.ParseBytes(rawJSON) + var results []string + + // Initialize parameters if needed + if param.MessageID == "" { + param.MessageID = root.Get("id").String() + } + if param.Model == "" { + param.Model = root.Get("model").String() + } + if param.CreatedAt == 0 { + param.CreatedAt = root.Get("created").Int() + } + + // Check if this is the first chunk (has role) + if delta := root.Get("choices.0.delta"); delta.Exists() { + if role := delta.Get("role"); role.Exists() && role.String() == "assistant" { + // Send message_start event + messageStart := map[string]interface{}{ + "type": "message_start", + "message": map[string]interface{}{ + "id": param.MessageID, + "type": "message", + "role": "assistant", + "model": param.Model, + "content": []interface{}{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]interface{}{ + "input_tokens": 0, + "output_tokens": 0, + }, + }, + } + messageStartJSON, _ := json.Marshal(messageStart) + results = append(results, "event: message_start\ndata: "+string(messageStartJSON)+"\n\n") + + // Don't send content_block_start for text here - wait for actual content + } + + // Handle content delta + if content := delta.Get("content"); content.Exists() && 
content.String() != "" { + // Send content_block_start for text if not already sent + if !param.TextContentBlockStarted { + contentBlockStart := map[string]interface{}{ + "type": "content_block_start", + "index": 0, + "content_block": map[string]interface{}{ + "type": "text", + "text": "", + }, + } + contentBlockStartJSON, _ := json.Marshal(contentBlockStart) + results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n") + param.TextContentBlockStarted = true + } + + contentDelta := map[string]interface{}{ + "type": "content_block_delta", + "index": 0, + "delta": map[string]interface{}{ + "type": "text_delta", + "text": content.String(), + }, + } + contentDeltaJSON, _ := json.Marshal(contentDelta) + results = append(results, "event: content_block_delta\ndata: "+string(contentDeltaJSON)+"\n\n") + + // Accumulate content + param.ContentAccumulator.WriteString(content.String()) + } + + // Handle tool calls + if toolCalls := delta.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() { + if param.ToolCallsAccumulator == nil { + param.ToolCallsAccumulator = make(map[int]*ToolCallAccumulator) + } + + toolCalls.ForEach(func(_, toolCall gjson.Result) bool { + index := int(toolCall.Get("index").Int()) + + // Initialize accumulator if needed + if _, exists := param.ToolCallsAccumulator[index]; !exists { + param.ToolCallsAccumulator[index] = &ToolCallAccumulator{} + } + + accumulator := param.ToolCallsAccumulator[index] + + // Handle tool call ID + if id := toolCall.Get("id"); id.Exists() { + accumulator.ID = id.String() + } + + // Handle function name + if function := toolCall.Get("function"); function.Exists() { + if name := function.Get("name"); name.Exists() { + accumulator.Name = name.String() + + // Send content_block_start for tool_use + contentBlockStart := map[string]interface{}{ + "type": "content_block_start", + "index": index + 1, // Offset by 1 since text is at index 0 + "content_block": map[string]interface{}{ + "type": "tool_use", + "id": accumulator.ID, + "name": accumulator.Name, + "input": map[string]interface{}{}, + }, + } + contentBlockStartJSON, _ := json.Marshal(contentBlockStart) + results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n") + } + + // Handle function arguments + if args := function.Get("arguments"); args.Exists() { + argsText := args.String() + accumulator.Arguments.WriteString(argsText) + + // Send input_json_delta + inputDelta := map[string]interface{}{ + "type": "content_block_delta", + "index": index + 1, + "delta": map[string]interface{}{ + "type": "input_json_delta", + "partial_json": argsText, + }, + } + inputDeltaJSON, _ := json.Marshal(inputDelta) + results = append(results, "event: content_block_delta\ndata: "+string(inputDeltaJSON)+"\n\n") + } + } + + return true + }) + } + } + + // Handle finish_reason (but don't send message_delta/message_stop yet) + if finishReason := root.Get("choices.0.finish_reason"); finishReason.Exists() && finishReason.String() != "" { + reason := finishReason.String() + param.FinishReason = reason + + // Send content_block_stop for text if text content block was started + if param.TextContentBlockStarted && !param.ContentBlocksStopped { + contentBlockStop := map[string]interface{}{ + "type": "content_block_stop", + "index": 0, + } + contentBlockStopJSON, _ := json.Marshal(contentBlockStop) + results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n") + } + + // Send content_block_stop for any tool 
calls + if !param.ContentBlocksStopped { + for index := range param.ToolCallsAccumulator { + contentBlockStop := map[string]interface{}{ + "type": "content_block_stop", + "index": index + 1, + } + contentBlockStopJSON, _ := json.Marshal(contentBlockStop) + results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n") + } + param.ContentBlocksStopped = true + } + + // Don't send message_delta here - wait for usage info or [DONE] + } + + // Handle usage information separately (this comes in a later chunk) + // Only process if usage has actual values (not null) + if usage := root.Get("usage"); usage.Exists() && usage.Type != gjson.Null && param.FinishReason != "" { + // Check if usage has actual token counts + promptTokens := usage.Get("prompt_tokens") + completionTokens := usage.Get("completion_tokens") + + if promptTokens.Exists() && completionTokens.Exists() { + // Send message_delta with usage + messageDelta := map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": mapOpenAIFinishReasonToAnthropic(param.FinishReason), + "stop_sequence": nil, + }, + "usage": map[string]interface{}{ + "input_tokens": promptTokens.Int(), + "output_tokens": completionTokens.Int(), + }, + } + + messageDeltaJSON, _ := json.Marshal(messageDelta) + results = append(results, "event: message_delta\ndata: "+string(messageDeltaJSON)+"\n\n") + param.MessageDeltaSent = true + } + } + + return results +} + +// convertOpenAIDoneToAnthropic handles the [DONE] marker and sends final events +func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams) []string { + var results []string + + // If we haven't sent message_delta yet (no usage info was received), send it now + if param.FinishReason != "" && !param.MessageDeltaSent { + messageDelta := map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": mapOpenAIFinishReasonToAnthropic(param.FinishReason), + "stop_sequence": nil, + }, + } + + messageDeltaJSON, _ := json.Marshal(messageDelta) + results = append(results, "event: message_delta\ndata: "+string(messageDeltaJSON)+"\n\n") + param.MessageDeltaSent = true + } + + // Send message_stop + results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n") + + return results +} + +// convertOpenAINonStreamingToAnthropic converts OpenAI non-streaming response to Anthropic format +func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string { + root := gjson.ParseBytes(rawJSON) + + // Build Anthropic response + response := map[string]interface{}{ + "id": root.Get("id").String(), + "type": "message", + "role": "assistant", + "model": root.Get("model").String(), + "content": []interface{}{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]interface{}{ + "input_tokens": 0, + "output_tokens": 0, + }, + } + + // Process message content and tool calls + var contentBlocks []interface{} + + if choices := root.Get("choices"); choices.Exists() && choices.IsArray() { + choice := choices.Array()[0] // Take first choice + + // Handle text content + if content := choice.Get("message.content"); content.Exists() && content.String() != "" { + textBlock := map[string]interface{}{ + "type": "text", + "text": content.String(), + } + contentBlocks = append(contentBlocks, textBlock) + } + + // Handle tool calls + if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() { + toolCalls.ForEach(func(_, toolCall 
gjson.Result) bool { + toolUseBlock := map[string]interface{}{ + "type": "tool_use", + "id": toolCall.Get("id").String(), + "name": toolCall.Get("function.name").String(), + } + + // Parse arguments + argsStr := toolCall.Get("function.arguments").String() + if argsStr != "" { + var args interface{} + if err := json.Unmarshal([]byte(argsStr), &args); err == nil { + toolUseBlock["input"] = args + } else { + toolUseBlock["input"] = map[string]interface{}{} + } + } else { + toolUseBlock["input"] = map[string]interface{}{} + } + + contentBlocks = append(contentBlocks, toolUseBlock) + return true + }) + } + + // Set stop reason + if finishReason := choice.Get("finish_reason"); finishReason.Exists() { + response["stop_reason"] = mapOpenAIFinishReasonToAnthropic(finishReason.String()) + } + } + + response["content"] = contentBlocks + + // Set usage information + if usage := root.Get("usage"); usage.Exists() { + response["usage"] = map[string]interface{}{ + "input_tokens": usage.Get("prompt_tokens").Int(), + "output_tokens": usage.Get("completion_tokens").Int(), + } + } + + responseJSON, _ := json.Marshal(response) + return []string{string(responseJSON)} +} + +// mapOpenAIFinishReasonToAnthropic maps OpenAI finish reasons to Anthropic equivalents +func mapOpenAIFinishReasonToAnthropic(openAIReason string) string { + switch openAIReason { + case "stop": + return "end_turn" + case "length": + return "max_tokens" + case "tool_calls": + return "tool_use" + case "content_filter": + return "end_turn" // Anthropic doesn't have direct equivalent + case "function_call": // Legacy OpenAI + return "tool_use" + default: + return "end_turn" + } +} diff --git a/internal/translator/openai/gemini/openai_gemini_request.go b/internal/translator/openai/gemini/openai_gemini_request.go new file mode 100644 index 00000000..d535542e --- /dev/null +++ b/internal/translator/openai/gemini/openai_gemini_request.go @@ -0,0 +1,359 @@ +// Package gemini provides request translation functionality for Gemini to OpenAI API. +// It handles parsing and transforming Gemini API requests into OpenAI Chat Completions API format, +// extracting model information, generation config, message contents, and tool declarations. +// The package performs JSON data transformation to ensure compatibility +// between Gemini API format and OpenAI API's expected format. +package gemini + +import ( + "crypto/rand" + "encoding/json" + "math/big" + "strings" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ConvertGeminiRequestToOpenAI parses and transforms a Gemini API request into OpenAI Chat Completions API format. +// It extracts the model name, generation config, message contents, and tool declarations +// from the raw JSON request and returns them in the format expected by the OpenAI API. 
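+//
+// Notable mappings: contents[].role "model" becomes "assistant"; generationConfig
+// fields (temperature, maxOutputTokens, topP, stopSequences) map to their OpenAI
+// equivalents; functionCall parts become tool_calls with generated "call_..." IDs;
+// functionResponse parts are emitted as separate "tool" role messages.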
+func ConvertGeminiRequestToOpenAI(rawJSON []byte) string { + // Base OpenAI Chat Completions API template + out := `{"model":"","messages":[]}` + + root := gjson.ParseBytes(rawJSON) + + // Helper for generating tool call IDs in the form: call_ + genToolCallID := func() string { + const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + var b strings.Builder + // 24 chars random suffix + for i := 0; i < 24; i++ { + n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) + b.WriteByte(letters[n.Int64()]) + } + return "call_" + b.String() + } + + // Model mapping + if model := root.Get("model"); model.Exists() { + modelStr := model.String() + out, _ = sjson.Set(out, "model", modelStr) + } + + // Generation config mapping + if genConfig := root.Get("generationConfig"); genConfig.Exists() { + // Temperature + if temp := genConfig.Get("temperature"); temp.Exists() { + out, _ = sjson.Set(out, "temperature", temp.Float()) + } + + // Max tokens + if maxTokens := genConfig.Get("maxOutputTokens"); maxTokens.Exists() { + out, _ = sjson.Set(out, "max_tokens", maxTokens.Int()) + } + + // Top P + if topP := genConfig.Get("topP"); topP.Exists() { + out, _ = sjson.Set(out, "top_p", topP.Float()) + } + + // Top K (OpenAI doesn't have direct equivalent, but we can map it) + if topK := genConfig.Get("topK"); topK.Exists() { + // Store as custom parameter for potential use + out, _ = sjson.Set(out, "top_k", topK.Int()) + } + + // Stop sequences + if stopSequences := genConfig.Get("stopSequences"); stopSequences.Exists() && stopSequences.IsArray() { + var stops []string + stopSequences.ForEach(func(_, value gjson.Result) bool { + stops = append(stops, value.String()) + return true + }) + if len(stops) > 0 { + out, _ = sjson.Set(out, "stop", stops) + } + } + } + + // Stream parameter + if stream := root.Get("stream"); stream.Exists() { + out, _ = sjson.Set(out, "stream", stream.Bool()) + } + + // Process contents (Gemini messages) -> OpenAI messages + var openAIMessages []interface{} + var toolCallIDs []string // Track tool call IDs for matching with tool results + + if contents := root.Get("contents"); contents.Exists() && contents.IsArray() { + contents.ForEach(func(_, content gjson.Result) bool { + role := content.Get("role").String() + parts := content.Get("parts") + + // Convert role: model -> assistant + if role == "model" { + role = "assistant" + } + + // Create OpenAI message + msg := map[string]interface{}{ + "role": role, + "content": "", + } + + var contentParts []string + var toolCalls []interface{} + + if parts.Exists() && parts.IsArray() { + parts.ForEach(func(_, part gjson.Result) bool { + // Handle text parts + if text := part.Get("text"); text.Exists() { + contentParts = append(contentParts, text.String()) + } + + // Handle function calls (Gemini) -> tool calls (OpenAI) + if functionCall := part.Get("functionCall"); functionCall.Exists() { + toolCallID := genToolCallID() + toolCallIDs = append(toolCallIDs, toolCallID) + + toolCall := map[string]interface{}{ + "id": toolCallID, + "type": "function", + "function": map[string]interface{}{ + "name": functionCall.Get("name").String(), + }, + } + + // Convert args to arguments JSON string + if args := functionCall.Get("args"); args.Exists() { + argsJSON, _ := json.Marshal(args.Value()) + toolCall["function"].(map[string]interface{})["arguments"] = string(argsJSON) + } else { + toolCall["function"].(map[string]interface{})["arguments"] = "{}" + } + + toolCalls = append(toolCalls, toolCall) + } + + // Handle function 
responses (Gemini) -> tool role messages (OpenAI) + if functionResponse := part.Get("functionResponse"); functionResponse.Exists() { + // Create tool message for function response + toolMsg := map[string]interface{}{ + "role": "tool", + "tool_call_id": "", // Will be set based on context + "content": "", + } + + // Convert response.content to JSON string + if response := functionResponse.Get("response"); response.Exists() { + if content = response.Get("content"); content.Exists() { + // Use the content field from the response + contentJSON, _ := json.Marshal(content.Value()) + toolMsg["content"] = string(contentJSON) + } else { + // Fallback to entire response + responseJSON, _ := json.Marshal(response.Value()) + toolMsg["content"] = string(responseJSON) + } + } + + // Try to match with previous tool call ID + _ = functionResponse.Get("name").String() // functionName not used for now + if len(toolCallIDs) > 0 { + // Use the last tool call ID (simple matching by function name) + // In a real implementation, you might want more sophisticated matching + toolMsg["tool_call_id"] = toolCallIDs[len(toolCallIDs)-1] + } else { + // Generate a tool call ID if none available + toolMsg["tool_call_id"] = genToolCallID() + } + + openAIMessages = append(openAIMessages, toolMsg) + } + + return true + }) + } + + // Set content + if len(contentParts) > 0 { + msg["content"] = strings.Join(contentParts, "") + } + + // Set tool calls if any + if len(toolCalls) > 0 { + msg["tool_calls"] = toolCalls + } + + openAIMessages = append(openAIMessages, msg) + + // switch role { + // case "user", "model": + // // Convert role: model -> assistant + // if role == "model" { + // role = "assistant" + // } + // + // // Create OpenAI message + // msg := map[string]interface{}{ + // "role": role, + // "content": "", + // } + // + // var contentParts []string + // var toolCalls []interface{} + // + // if parts.Exists() && parts.IsArray() { + // parts.ForEach(func(_, part gjson.Result) bool { + // // Handle text parts + // if text := part.Get("text"); text.Exists() { + // contentParts = append(contentParts, text.String()) + // } + // + // // Handle function calls (Gemini) -> tool calls (OpenAI) + // if functionCall := part.Get("functionCall"); functionCall.Exists() { + // toolCallID := genToolCallID() + // toolCallIDs = append(toolCallIDs, toolCallID) + // + // toolCall := map[string]interface{}{ + // "id": toolCallID, + // "type": "function", + // "function": map[string]interface{}{ + // "name": functionCall.Get("name").String(), + // }, + // } + // + // // Convert args to arguments JSON string + // if args := functionCall.Get("args"); args.Exists() { + // argsJSON, _ := json.Marshal(args.Value()) + // toolCall["function"].(map[string]interface{})["arguments"] = string(argsJSON) + // } else { + // toolCall["function"].(map[string]interface{})["arguments"] = "{}" + // } + // + // toolCalls = append(toolCalls, toolCall) + // } + // + // return true + // }) + // } + // + // // Set content + // if len(contentParts) > 0 { + // msg["content"] = strings.Join(contentParts, "") + // } + // + // // Set tool calls if any + // if len(toolCalls) > 0 { + // msg["tool_calls"] = toolCalls + // } + // + // openAIMessages = append(openAIMessages, msg) + // + // case "function": + // // Handle Gemini function role -> OpenAI tool role + // if parts.Exists() && parts.IsArray() { + // parts.ForEach(func(_, part gjson.Result) bool { + // // Handle function responses (Gemini) -> tool role messages (OpenAI) + // if functionResponse := 
part.Get("functionResponse"); functionResponse.Exists() { + // // Create tool message for function response + // toolMsg := map[string]interface{}{ + // "role": "tool", + // "tool_call_id": "", // Will be set based on context + // "content": "", + // } + // + // // Convert response.content to JSON string + // if response := functionResponse.Get("response"); response.Exists() { + // if content = response.Get("content"); content.Exists() { + // // Use the content field from the response + // contentJSON, _ := json.Marshal(content.Value()) + // toolMsg["content"] = string(contentJSON) + // } else { + // // Fallback to entire response + // responseJSON, _ := json.Marshal(response.Value()) + // toolMsg["content"] = string(responseJSON) + // } + // } + // + // // Try to match with previous tool call ID + // _ = functionResponse.Get("name").String() // functionName not used for now + // if len(toolCallIDs) > 0 { + // // Use the last tool call ID (simple matching by function name) + // // In a real implementation, you might want more sophisticated matching + // toolMsg["tool_call_id"] = toolCallIDs[len(toolCallIDs)-1] + // } else { + // // Generate a tool call ID if none available + // toolMsg["tool_call_id"] = genToolCallID() + // } + // + // openAIMessages = append(openAIMessages, toolMsg) + // } + // + // return true + // }) + // } + // } + return true + }) + } + + // Set messages + if len(openAIMessages) > 0 { + messagesJSON, _ := json.Marshal(openAIMessages) + out, _ = sjson.SetRaw(out, "messages", string(messagesJSON)) + } + + // Tools mapping: Gemini tools -> OpenAI tools + if tools := root.Get("tools"); tools.Exists() && tools.IsArray() { + var openAITools []interface{} + tools.ForEach(func(_, tool gjson.Result) bool { + if functionDeclarations := tool.Get("functionDeclarations"); functionDeclarations.Exists() && functionDeclarations.IsArray() { + functionDeclarations.ForEach(func(_, funcDecl gjson.Result) bool { + openAITool := map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": funcDecl.Get("name").String(), + "description": funcDecl.Get("description").String(), + }, + } + + // Convert parameters schema + if parameters := funcDecl.Get("parameters"); parameters.Exists() { + openAITool["function"].(map[string]interface{})["parameters"] = parameters.Value() + } else if parameters = funcDecl.Get("parametersJsonSchema"); parameters.Exists() { + openAITool["function"].(map[string]interface{})["parameters"] = parameters.Value() + } + + openAITools = append(openAITools, openAITool) + return true + }) + } + return true + }) + + if len(openAITools) > 0 { + toolsJSON, _ := json.Marshal(openAITools) + out, _ = sjson.SetRaw(out, "tools", string(toolsJSON)) + } + } + + // Tool choice mapping (Gemini doesn't have direct equivalent, but we can handle it) + if toolConfig := root.Get("toolConfig"); toolConfig.Exists() { + if functionCallingConfig := toolConfig.Get("functionCallingConfig"); functionCallingConfig.Exists() { + mode := functionCallingConfig.Get("mode").String() + switch mode { + case "NONE": + out, _ = sjson.Set(out, "tool_choice", "none") + case "AUTO": + out, _ = sjson.Set(out, "tool_choice", "auto") + case "ANY": + out, _ = sjson.Set(out, "tool_choice", "required") + } + } + } + + return out +} diff --git a/internal/translator/openai/gemini/openai_gemini_response.go b/internal/translator/openai/gemini/openai_gemini_response.go new file mode 100644 index 00000000..17226f11 --- /dev/null +++ 
b/internal/translator/openai/gemini/openai_gemini_response.go @@ -0,0 +1,353 @@ +// Package gemini provides response translation functionality for OpenAI to Gemini API. +// This package handles the conversion of OpenAI Chat Completions API responses into Gemini API-compatible +// JSON format, transforming streaming events and non-streaming responses into the format +// expected by Gemini API clients. It supports both streaming and non-streaming modes, +// handling text content, tool calls, and usage metadata appropriately. +package gemini + +import ( + "encoding/json" + "strings" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ConvertOpenAIResponseToGeminiParams holds parameters for response conversion +type ConvertOpenAIResponseToGeminiParams struct { + // Tool calls accumulator for streaming + ToolCallsAccumulator map[int]*ToolCallAccumulator + // Content accumulator for streaming + ContentAccumulator strings.Builder + // Track if this is the first chunk + IsFirstChunk bool +} + +// ToolCallAccumulator holds the state for accumulating tool call data +type ToolCallAccumulator struct { + ID string + Name string + Arguments strings.Builder +} + +// ConvertOpenAIResponseToGemini converts OpenAI Chat Completions streaming response format to Gemini API format. +// This function processes OpenAI streaming chunks and transforms them into Gemini-compatible JSON responses. +// It handles text content, tool calls, and usage metadata, outputting responses that match the Gemini API format. +func ConvertOpenAIResponseToGemini(rawJSON []byte, param *ConvertOpenAIResponseToGeminiParams) []string { + // Handle [DONE] marker + if strings.TrimSpace(string(rawJSON)) == "[DONE]" { + return []string{} + } + + root := gjson.ParseBytes(rawJSON) + + // Initialize accumulators if needed + if param.ToolCallsAccumulator == nil { + param.ToolCallsAccumulator = make(map[int]*ToolCallAccumulator) + } + + // Process choices + if choices := root.Get("choices"); choices.Exists() && choices.IsArray() { + // Handle empty choices array (usage-only chunk) + if len(choices.Array()) == 0 { + // This is a usage-only chunk, handle usage and return + if usage := root.Get("usage"); usage.Exists() { + template := `{"candidates":[],"usageMetadata":{}}` + + // Set model if available + if model := root.Get("model"); model.Exists() { + template, _ = sjson.Set(template, "model", model.String()) + } + + usageObj := map[string]interface{}{ + "promptTokenCount": usage.Get("prompt_tokens").Int(), + "candidatesTokenCount": usage.Get("completion_tokens").Int(), + "totalTokenCount": usage.Get("total_tokens").Int(), + } + template, _ = sjson.Set(template, "usageMetadata", usageObj) + return []string{template} + } + return []string{} + } + + var results []string + + choices.ForEach(func(choiceIndex, choice gjson.Result) bool { + // Base Gemini response template + template := `{"candidates":[{"content":{"parts":[],"role":"model"},"finishReason":"STOP","index":0}]}` + + // Set model if available + if model := root.Get("model"); model.Exists() { + template, _ = sjson.Set(template, "model", model.String()) + } + + _ = int(choice.Get("index").Int()) // choiceIdx not used in streaming + delta := choice.Get("delta") + + // Handle role (only in first chunk) + if role := delta.Get("role"); role.Exists() && param.IsFirstChunk { + // OpenAI assistant -> Gemini model + if role.String() == "assistant" { + template, _ = sjson.Set(template, "candidates.0.content.role", "model") + } + param.IsFirstChunk = false + results = append(results, 
template) + return true + } + + // Handle content delta + if content := delta.Get("content"); content.Exists() && content.String() != "" { + contentText := content.String() + param.ContentAccumulator.WriteString(contentText) + + // Create text part for this delta + parts := []interface{}{ + map[string]interface{}{ + "text": contentText, + }, + } + template, _ = sjson.Set(template, "candidates.0.content.parts", parts) + results = append(results, template) + return true + } + + // Handle tool calls delta + if toolCalls := delta.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() { + toolCalls.ForEach(func(_, toolCall gjson.Result) bool { + toolIndex := int(toolCall.Get("index").Int()) + toolID := toolCall.Get("id").String() + toolType := toolCall.Get("type").String() + + if toolType == "function" { + function := toolCall.Get("function") + functionName := function.Get("name").String() + functionArgs := function.Get("arguments").String() + + // Initialize accumulator if needed + if _, exists := param.ToolCallsAccumulator[toolIndex]; !exists { + param.ToolCallsAccumulator[toolIndex] = &ToolCallAccumulator{ + ID: toolID, + Name: functionName, + } + } + + // Update ID if provided + if toolID != "" { + param.ToolCallsAccumulator[toolIndex].ID = toolID + } + + // Update name if provided + if functionName != "" { + param.ToolCallsAccumulator[toolIndex].Name = functionName + } + + // Accumulate arguments + if functionArgs != "" { + param.ToolCallsAccumulator[toolIndex].Arguments.WriteString(functionArgs) + } + } + return true + }) + + // Don't output anything for tool call deltas - wait for completion + return true + } + + // Handle finish reason + if finishReason := choice.Get("finish_reason"); finishReason.Exists() { + geminiFinishReason := mapOpenAIFinishReasonToGemini(finishReason.String()) + template, _ = sjson.Set(template, "candidates.0.finishReason", geminiFinishReason) + + // If we have accumulated tool calls, output them now + if len(param.ToolCallsAccumulator) > 0 { + var parts []interface{} + for _, accumulator := range param.ToolCallsAccumulator { + argsStr := accumulator.Arguments.String() + var argsMap map[string]interface{} + + if argsStr != "" && argsStr != "{}" { + // Handle malformed JSON by trying to fix common issues + fixedArgs := argsStr + // Fix unquoted keys and values (common in the sample) + if strings.Contains(fixedArgs, "北京") && !strings.Contains(fixedArgs, "\"北京\"") { + fixedArgs = strings.ReplaceAll(fixedArgs, "北京", "\"北京\"") + } + if strings.Contains(fixedArgs, "celsius") && !strings.Contains(fixedArgs, "\"celsius\"") { + fixedArgs = strings.ReplaceAll(fixedArgs, "celsius", "\"celsius\"") + } + + if err := json.Unmarshal([]byte(fixedArgs), &argsMap); err != nil { + // If still fails, try to parse as raw string + if err2 := json.Unmarshal([]byte("\""+argsStr+"\""), &argsMap); err2 != nil { + // Last resort: use empty object + argsMap = map[string]interface{}{} + } + } + } else { + argsMap = map[string]interface{}{} + } + + functionCallPart := map[string]interface{}{ + "functionCall": map[string]interface{}{ + "name": accumulator.Name, + "args": argsMap, + }, + } + parts = append(parts, functionCallPart) + } + + if len(parts) > 0 { + template, _ = sjson.Set(template, "candidates.0.content.parts", parts) + } + + // Clear accumulators + param.ToolCallsAccumulator = make(map[int]*ToolCallAccumulator) + } + + results = append(results, template) + return true + } + + // Handle usage information + if usage := root.Get("usage"); usage.Exists() { + usageObj := 
map[string]interface{}{ + "promptTokenCount": usage.Get("prompt_tokens").Int(), + "candidatesTokenCount": usage.Get("completion_tokens").Int(), + "totalTokenCount": usage.Get("total_tokens").Int(), + } + template, _ = sjson.Set(template, "usageMetadata", usageObj) + results = append(results, template) + return true + } + + return true + }) + return results + } + return []string{} +} + +// mapOpenAIFinishReasonToGemini maps OpenAI finish reasons to Gemini finish reasons +func mapOpenAIFinishReasonToGemini(openAIReason string) string { + switch openAIReason { + case "stop": + return "STOP" + case "length": + return "MAX_TOKENS" + case "tool_calls": + return "STOP" // Gemini doesn't have a specific tool_calls finish reason + case "content_filter": + return "SAFETY" + default: + return "STOP" + } +} + +// ConvertOpenAINonStreamResponseToGemini converts OpenAI non-streaming response to Gemini format +func ConvertOpenAINonStreamResponseToGemini(rawJSON []byte) string { + root := gjson.ParseBytes(rawJSON) + + // Base Gemini response template + out := `{"candidates":[{"content":{"parts":[],"role":"model"},"finishReason":"STOP","index":0}]}` + + // Set model if available + if model := root.Get("model"); model.Exists() { + out, _ = sjson.Set(out, "model", model.String()) + } + + // Process choices + if choices := root.Get("choices"); choices.Exists() && choices.IsArray() { + choices.ForEach(func(choiceIndex, choice gjson.Result) bool { + choiceIdx := int(choice.Get("index").Int()) + message := choice.Get("message") + + // Set role + if role := message.Get("role"); role.Exists() { + if role.String() == "assistant" { + out, _ = sjson.Set(out, "candidates.0.content.role", "model") + } + } + + var parts []interface{} + + // Handle content first + if content := message.Get("content"); content.Exists() && content.String() != "" { + parts = append(parts, map[string]interface{}{ + "text": content.String(), + }) + } + + // Handle tool calls + if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() { + toolCalls.ForEach(func(_, toolCall gjson.Result) bool { + if toolCall.Get("type").String() == "function" { + function := toolCall.Get("function") + functionName := function.Get("name").String() + functionArgs := function.Get("arguments").String() + + // Parse arguments + var argsMap map[string]interface{} + if functionArgs != "" && functionArgs != "{}" { + // Handle malformed JSON by trying to fix common issues + fixedArgs := functionArgs + // Fix unquoted keys and values (common in the sample) + if strings.Contains(fixedArgs, "北京") && !strings.Contains(fixedArgs, "\"北京\"") { + fixedArgs = strings.ReplaceAll(fixedArgs, "北京", "\"北京\"") + } + if strings.Contains(fixedArgs, "celsius") && !strings.Contains(fixedArgs, "\"celsius\"") { + fixedArgs = strings.ReplaceAll(fixedArgs, "celsius", "\"celsius\"") + } + + if err := json.Unmarshal([]byte(fixedArgs), &argsMap); err != nil { + // If still fails, try to parse as raw string + if err2 := json.Unmarshal([]byte("\""+functionArgs+"\""), &argsMap); err2 != nil { + // Last resort: use empty object + argsMap = map[string]interface{}{} + } + } + } else { + argsMap = map[string]interface{}{} + } + + functionCallPart := map[string]interface{}{ + "functionCall": map[string]interface{}{ + "name": functionName, + "args": argsMap, + }, + } + parts = append(parts, functionCallPart) + } + return true + }) + } + + // Set parts + if len(parts) > 0 { + out, _ = sjson.Set(out, "candidates.0.content.parts", parts) + } + + // Handle finish reason + if 
finishReason := choice.Get("finish_reason"); finishReason.Exists() { + geminiFinishReason := mapOpenAIFinishReasonToGemini(finishReason.String()) + out, _ = sjson.Set(out, "candidates.0.finishReason", geminiFinishReason) + } + + // Set index + out, _ = sjson.Set(out, "candidates.0.index", choiceIdx) + + return true + }) + } + + // Handle usage information + if usage := root.Get("usage"); usage.Exists() { + usageObj := map[string]interface{}{ + "promptTokenCount": usage.Get("prompt_tokens").Int(), + "candidatesTokenCount": usage.Get("completion_tokens").Int(), + "totalTokenCount": usage.Get("total_tokens").Int(), + } + out, _ = sjson.Set(out, "usageMetadata", usageObj) + } + + return out +} diff --git a/internal/util/provider.go b/internal/util/provider.go index bcebe30a..3e330e36 100644 --- a/internal/util/provider.go +++ b/internal/util/provider.go @@ -23,6 +23,8 @@ func GetProviderName(modelName string) string { return "gpt" } else if strings.HasPrefix(modelName, "claude") { return "claude" + } else if strings.HasPrefix(modelName, "qwen") { + return "qwen" } return "unknow" } diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 05921d17..d00c65c8 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -19,6 +19,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/auth/claude" "github.com/luispater/CLIProxyAPI/internal/auth/codex" "github.com/luispater/CLIProxyAPI/internal/auth/gemini" + "github.com/luispater/CLIProxyAPI/internal/auth/qwen" "github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/config" "github.com/luispater/CLIProxyAPI/internal/util" @@ -281,6 +282,20 @@ func (w *Watcher) reloadClients() { } else { log.Errorf(" failed to decode token file %s: %v", path, err) } + } else if tokenType == "qwen" { + var ts qwen.QwenTokenStorage + if err = json.Unmarshal(data, &ts); err == nil { + // For each valid token, create an authenticated client + log.Debugf(" initializing qwen authentication for token from %s...", filepath.Base(path)) + qwenClient := client.NewQwenClient(cfg, &ts) + log.Debugf(" authentication successful for token from %s", filepath.Base(path)) + + // Add the new client to the pool + newClients = append(newClients, qwenClient) + successfulAuthCount++ + } else { + log.Errorf(" failed to decode token file %s: %v", path, err) + } } } return nil