diff --git a/cmd/server/main.go b/cmd/server/main.go index df0d8350..cea5e6ee 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -1,3 +1,6 @@ +// Package main provides the entry point for the CLI Proxy API server. +// This server acts as a proxy that provides OpenAI/Gemini/Claude compatible API interfaces +// for CLI models, allowing CLI models to be used with tools and libraries designed for standard AI APIs. package main import ( diff --git a/internal/api/claude-code-handlers.go b/internal/api/handlers/claude/code-handlers.go similarity index 72% rename from internal/api/claude-code-handlers.go rename to internal/api/handlers/claude/code-handlers.go index ff825a7b..4427cc43 100644 --- a/internal/api/claude-code-handlers.go +++ b/internal/api/handlers/claude/code-handlers.go @@ -1,10 +1,17 @@ -package api +// Package claude provides HTTP handlers for Claude API code-related functionality. +// This package implements Claude-compatible streaming chat completions with sophisticated +// client rotation and quota management systems to ensure high availability and optimal +// resource utilization across multiple backend clients. It handles request translation +// between Claude API format and the underlying Gemini backend, providing seamless +// API compatibility while maintaining robust error handling and connection management. +package claude import ( "context" "fmt" "github.com/gin-gonic/gin" - "github.com/luispater/CLIProxyAPI/internal/api/translator" + "github.com/luispater/CLIProxyAPI/internal/api/handlers" + "github.com/luispater/CLIProxyAPI/internal/api/translator/claude/code" "github.com/luispater/CLIProxyAPI/internal/client" log "github.com/sirupsen/logrus" "net/http" @@ -12,16 +19,30 @@ import ( "time" ) +// ClaudeCodeAPIHandlers contains the handlers for Claude API endpoints. +// It holds a pool of clients to interact with the backend service. +type ClaudeCodeAPIHandlers struct { + *handlers.APIHandlers +} + +// NewClaudeCodeAPIHandlers creates a new Claude API handlers instance. +// It takes an APIHandlers instance as input and returns a ClaudeCodeAPIHandlers. +func NewClaudeCodeAPIHandlers(apiHandlers *handlers.APIHandlers) *ClaudeCodeAPIHandlers { + return &ClaudeCodeAPIHandlers{ + APIHandlers: apiHandlers, + } +} + // ClaudeMessages handles Claude-compatible streaming chat completions. // This function implements a sophisticated client rotation and quota management system // to ensure high availability and optimal resource utilization across multiple backend clients. -func (h *APIHandlers) ClaudeMessages(c *gin.Context) { +func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) { // Extract raw JSON data from the incoming request - rawJson, err := c.GetRawData() + rawJSON, err := c.GetRawData() // If data retrieval fails, return a 400 Bad Request error. if err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: fmt.Sprintf("Invalid request: %v", err), Type: "invalid_request_error", }, @@ -41,8 +62,8 @@ func (h *APIHandlers) ClaudeMessages(c *gin.Context) { // This is crucial for streaming as it allows immediate sending of data chunks flusher, ok := c.Writer.(http.Flusher) if !ok { - c.JSON(http.StatusInternalServerError, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: "Streaming not supported", Type: "server_error", }, @@ -52,7 +73,7 @@ func (h *APIHandlers) ClaudeMessages(c *gin.Context) { // Parse and prepare the Claude request, extracting model name, system instructions, // conversation contents, and available tools from the raw JSON - modelName, systemInstruction, contents, tools := translator.PrepareClaudeRequest(rawJson) + modelName, systemInstruction, contents, tools := code.PrepareClaudeRequest(rawJSON) // Map Claude model names to corresponding Gemini models // This allows the proxy to handle Claude API calls using Gemini backend @@ -79,7 +100,7 @@ func (h *APIHandlers) ClaudeMessages(c *gin.Context) { outLoop: for { var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.getClient(modelName) + cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) _, _ = fmt.Fprint(c.Writer, errorResponse.Error) @@ -105,7 +126,7 @@ outLoop: includeThoughts = !strings.Contains(userAgent[0], "claude-cli") } - respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools, includeThoughts) + respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools, includeThoughts) // Track response state for proper Claude format conversion hasFirstResponse := false @@ -139,16 +160,15 @@ outLoop: flusher.Flush() cliCancel() return - } else { - // Convert the backend response to Claude-compatible format - // This translation layer ensures API compatibility - claudeFormat := translator.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex) - if claudeFormat != "" { - _, _ = c.Writer.Write([]byte(claudeFormat)) - flusher.Flush() // Immediately send the chunk to the client - } - hasFirstResponse = true } + // Convert the backend response to Claude-compatible format + // This translation layer ensures API compatibility + claudeFormat := code.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex) + if claudeFormat != "" { + _, _ = c.Writer.Write([]byte(claudeFormat)) + flusher.Flush() // Immediately send the chunk to the client + } + hasFirstResponse = true // Case 3: Handle errors from the backend // This manages various error conditions and implements retry logic @@ -156,7 +176,7 @@ outLoop: if okError { // Special handling for quota exceeded errors // If configured, attempt to switch to a different project/client - if errInfo.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { continue outLoop // Restart the client selection process } else { // Forward other errors directly to the client diff --git a/internal/api/cli-handlers.go b/internal/api/handlers/gemini/cli/cli-handlers.go similarity index 64% rename from internal/api/cli-handlers.go rename to internal/api/handlers/gemini/cli/cli-handlers.go index 75488907..7c1dde77 100644 --- a/internal/api/cli-handlers.go +++ b/internal/api/handlers/gemini/cli/cli-handlers.go @@ -1,10 +1,15 @@ -package api +// Package cli provides HTTP handlers for Gemini CLI API functionality. +// This package implements handlers that process CLI-specific requests for Gemini API operations, +// including content generation and streaming content generation endpoints. +// The handlers restrict access to localhost only and manage communication with the backend service. +package cli import ( "bytes" "context" "fmt" "github.com/gin-gonic/gin" + "github.com/luispater/CLIProxyAPI/internal/api/handlers" "github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" @@ -16,10 +21,26 @@ import ( "time" ) -func (h *APIHandlers) CLIHandler(c *gin.Context) { +// GeminiCLIAPIHandlers contains the handlers for Gemini CLI API endpoints. +// It holds a pool of clients to interact with the backend service. +type GeminiCLIAPIHandlers struct { + *handlers.APIHandlers +} + +// NewGeminiCLIAPIHandlers creates a new Gemini CLI API handlers instance. +// It takes an APIHandlers instance as input and returns a GeminiCLIAPIHandlers. +func NewGeminiCLIAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiCLIAPIHandlers { + return &GeminiCLIAPIHandlers{ + APIHandlers: apiHandlers, + } +} + +// CLIHandler handles CLI-specific requests for Gemini API operations. +// It restricts access to localhost only and routes requests to appropriate internal handlers. +func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) { if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") { - c.JSON(http.StatusForbidden, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusForbidden, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: "CLI reply only allow local access", Type: "forbidden", }, @@ -27,18 +48,18 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) { return } - rawJson, _ := c.GetRawData() + rawJSON, _ := c.GetRawData() requestRawURI := c.Request.URL.Path if requestRawURI == "/v1internal:generateContent" { - h.internalGenerateContent(c, rawJson) + h.internalGenerateContent(c, rawJSON) } else if requestRawURI == "/v1internal:streamGenerateContent" { - h.internalStreamGenerateContent(c, rawJson) + h.internalStreamGenerateContent(c, rawJSON) } else { - reqBody := bytes.NewBuffer(rawJson) + reqBody := bytes.NewBuffer(rawJSON) req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody) if err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: fmt.Sprintf("Invalid request: %v", err), Type: "invalid_request_error", }, @@ -49,15 +70,15 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) { req.Header[key] = value } - httpClient, err := util.SetProxy(h.cfg, &http.Client{}) + httpClient, err := util.SetProxy(h.Cfg, &http.Client{}) if err != nil { log.Fatalf("set proxy failed: %v", err) } resp, err := httpClient.Do(req) if err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: fmt.Sprintf("Invalid request: %v", err), Type: "invalid_request_error", }, @@ -73,8 +94,8 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) { }() bodyBytes, _ := io.ReadAll(resp.Body) - c.JSON(http.StatusBadRequest, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: string(bodyBytes), Type: "invalid_request_error", }, @@ -98,8 +119,8 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) { } } -func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) { - alt := h.getAlt(c) +func (h *GeminiCLIAPIHandlers) internalStreamGenerateContent(c *gin.Context, rawJSON []byte) { + alt := h.GetAlt(c) if alt == "" { c.Header("Content-Type", "text/event-stream") @@ -111,8 +132,8 @@ func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []by // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { - c.JSON(http.StatusInternalServerError, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: "Streaming not supported", Type: "server_error", }, @@ -120,7 +141,7 @@ func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []by return } - modelResult := gjson.GetBytes(rawJson, "model") + modelResult := gjson.GetBytes(rawJSON, "model") modelName := modelResult.String() cliCtx, cliCancel := context.WithCancel(context.Background()) @@ -135,7 +156,7 @@ func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []by outLoop: for { var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.getClient(modelName) + cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) _, _ = fmt.Fprint(c.Writer, errorResponse.Error) @@ -150,7 +171,7 @@ outLoop: log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) } // Send the message and receive response chunks and errors via channels. - respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, "") + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, "") hasFirstResponse := false for { select { @@ -166,20 +187,19 @@ outLoop: if !okStream { cliCancel() return - } else { - hasFirstResponse = true - if cliClient.GetGenerativeLanguageAPIKey() != "" { - chunk, _ = sjson.SetRawBytes(chunk, "response", chunk) - } - _, _ = c.Writer.Write([]byte("data: ")) - _, _ = c.Writer.Write(chunk) - _, _ = c.Writer.Write([]byte("\n\n")) - flusher.Flush() } + hasFirstResponse = true + if cliClient.GetGenerativeLanguageAPIKey() != "" { + chunk, _ = sjson.SetRawBytes(chunk, "response", chunk) + } + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + flusher.Flush() // Handle errors from the backend. case err, okError := <-errChan: if okError { - if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { continue outLoop } else { c.Status(err.StatusCode) @@ -200,10 +220,10 @@ outLoop: } } -func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) { +func (h *GeminiCLIAPIHandlers) internalGenerateContent(c *gin.Context, rawJSON []byte) { c.Header("Content-Type", "application/json") - modelResult := gjson.GetBytes(rawJson, "model") + modelResult := gjson.GetBytes(rawJSON, "model") modelName := modelResult.String() cliCtx, cliCancel := context.WithCancel(context.Background()) var cliClient *client.Client @@ -215,7 +235,7 @@ func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) { for { var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.getClient(modelName) + cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) _, _ = fmt.Fprint(c.Writer, errorResponse.Error) @@ -229,9 +249,9 @@ func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) { log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) } - resp, err := cliClient.SendRawMessage(cliCtx, rawJson, "") + resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, "") if err != nil { - if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { continue } else { c.Status(err.StatusCode) diff --git a/internal/api/gemini-handlers.go b/internal/api/handlers/gemini/gemini-handlers.go similarity index 67% rename from internal/api/gemini-handlers.go rename to internal/api/handlers/gemini/gemini-handlers.go index 0da2ca6d..160b5daf 100644 --- a/internal/api/gemini-handlers.go +++ b/internal/api/handlers/gemini/gemini-handlers.go @@ -1,10 +1,16 @@ -package api +// Package gemini provides HTTP handlers for Gemini API endpoints. +// This package implements handlers for managing Gemini model operations including +// model listing, content generation, streaming content generation, and token counting. +// It serves as a proxy layer between clients and the Gemini backend service, +// handling request translation, client management, and response processing. +package gemini import ( "context" "fmt" "github.com/gin-gonic/gin" - "github.com/luispater/CLIProxyAPI/internal/api/translator" + "github.com/luispater/CLIProxyAPI/internal/api/handlers" + "github.com/luispater/CLIProxyAPI/internal/api/translator/gemini/cli" "github.com/luispater/CLIProxyAPI/internal/client" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -14,7 +20,23 @@ import ( "time" ) -func (h *APIHandlers) GeminiModels(c *gin.Context) { +// GeminiAPIHandlers contains the handlers for Gemini API endpoints. +// It holds a pool of clients to interact with the backend service. +type GeminiAPIHandlers struct { + *handlers.APIHandlers +} + +// NewGeminiAPIHandlers creates a new Gemini API handlers instance. +// It takes an APIHandlers instance as input and returns a GeminiAPIHandlers. +func NewGeminiAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiAPIHandlers { + return &GeminiAPIHandlers{ + APIHandlers: apiHandlers, + } +} + +// GeminiModels handles the Gemini models listing endpoint. +// It returns a JSON response containing available Gemini models and their specifications. +func (h *GeminiAPIHandlers) GeminiModels(c *gin.Context) { c.Status(http.StatusOK) c.Header("Content-Type", "application/json; charset=UTF-8") _, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `)) @@ -30,13 +52,15 @@ func (h *APIHandlers) GeminiModels(c *gin.Context) { _, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`)) } -func (h *APIHandlers) GeminiGetHandler(c *gin.Context) { +// GeminiGetHandler handles GET requests for specific Gemini model information. +// It returns detailed information about a specific Gemini model based on the action parameter. +func (h *GeminiAPIHandlers) GeminiGetHandler(c *gin.Context) { var request struct { Action string `uri:"action" binding:"required"` } if err := c.ShouldBindUri(&request); err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: fmt.Sprintf("Invalid request: %v", err), Type: "invalid_request_error", }, @@ -68,13 +92,15 @@ func (h *APIHandlers) GeminiGetHandler(c *gin.Context) { } } -func (h *APIHandlers) GeminiHandler(c *gin.Context) { +// GeminiHandler handles POST requests for Gemini API operations. +// It routes requests to appropriate handlers based on the action parameter (model:method format). +func (h *GeminiAPIHandlers) GeminiHandler(c *gin.Context) { var request struct { Action string `uri:"action" binding:"required"` } if err := c.ShouldBindUri(&request); err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: fmt.Sprintf("Invalid request: %v", err), Type: "invalid_request_error", }, @@ -83,8 +109,8 @@ func (h *APIHandlers) GeminiHandler(c *gin.Context) { } action := strings.Split(request.Action, ":") if len(action) != 2 { - c.JSON(http.StatusNotFound, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusNotFound, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: fmt.Sprintf("%s not found.", c.Request.URL.Path), Type: "invalid_request_error", }, @@ -94,20 +120,20 @@ func (h *APIHandlers) GeminiHandler(c *gin.Context) { modelName := action[0] method := action[1] - rawJson, _ := c.GetRawData() - rawJson, _ = sjson.SetBytes(rawJson, "model", []byte(modelName)) + rawJSON, _ := c.GetRawData() + rawJSON, _ = sjson.SetBytes(rawJSON, "model", []byte(modelName)) if method == "generateContent" { - h.geminiGenerateContent(c, rawJson) + h.geminiGenerateContent(c, rawJSON) } else if method == "streamGenerateContent" { - h.geminiStreamGenerateContent(c, rawJson) + h.geminiStreamGenerateContent(c, rawJSON) } else if method == "countTokens" { - h.geminiCountTokens(c, rawJson) + h.geminiCountTokens(c, rawJSON) } } -func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte) { - alt := h.getAlt(c) +func (h *GeminiAPIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJSON []byte) { + alt := h.GetAlt(c) if alt == "" { c.Header("Content-Type", "text/event-stream") @@ -119,8 +145,8 @@ func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { - c.JSON(http.StatusInternalServerError, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: "Streaming not supported", Type: "server_error", }, @@ -128,7 +154,7 @@ func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte return } - modelResult := gjson.GetBytes(rawJson, "model") + modelResult := gjson.GetBytes(rawJSON, "model") modelName := modelResult.String() cliCtx, cliCancel := context.WithCancel(context.Background()) @@ -143,7 +169,7 @@ func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte outLoop: for { var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.getClient(modelName) + cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) _, _ = fmt.Fprint(c.Writer, errorResponse.Error) @@ -153,21 +179,21 @@ outLoop: } template := "" - parsed := gjson.Parse(string(rawJson)) + parsed := gjson.Parse(string(rawJSON)) contents := parsed.Get("request.contents") if contents.Exists() { - template = string(rawJson) + template = string(rawJSON) } else { template = `{"project":"","request":{},"model":""}` - template, _ = sjson.SetRaw(template, "request", string(rawJson)) + template, _ = sjson.SetRaw(template, "request", string(rawJSON)) template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) template, _ = sjson.Delete(template, "request.model") } - template, errFixCLIToolResponse := translator.FixCLIToolResponse(template) + template, errFixCLIToolResponse := cli.FixCLIToolResponse(template) if errFixCLIToolResponse != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: errFixCLIToolResponse.Error(), Type: "server_error", }, @@ -181,7 +207,7 @@ outLoop: template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) template, _ = sjson.Delete(template, "request.system_instruction") } - rawJson = []byte(template) + rawJSON = []byte(template) if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { log.Debugf("Request use generative language API Key: %s", glAPIKey) @@ -190,7 +216,7 @@ outLoop: } // Send the message and receive response chunks and errors via channels. - respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, alt) + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, alt) for { select { // Handle client disconnection. @@ -205,41 +231,40 @@ outLoop: if !okStream { cliCancel() return - } else { - if cliClient.GetGenerativeLanguageAPIKey() == "" { - if alt == "" { - responseResult := gjson.GetBytes(chunk, "response") - if responseResult.Exists() { - chunk = []byte(responseResult.Raw) - } - } else { - chunkTemplate := "[]" - responseResult := gjson.ParseBytes(chunk) - if responseResult.IsArray() { - responseResultItems := responseResult.Array() - for i := 0; i < len(responseResultItems); i++ { - responseResultItem := responseResultItems[i] - if responseResultItem.Get("response").Exists() { - chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw) - } + } + if cliClient.GetGenerativeLanguageAPIKey() == "" { + if alt == "" { + responseResult := gjson.GetBytes(chunk, "response") + if responseResult.Exists() { + chunk = []byte(responseResult.Raw) + } + } else { + chunkTemplate := "[]" + responseResult := gjson.ParseBytes(chunk) + if responseResult.IsArray() { + responseResultItems := responseResult.Array() + for i := 0; i < len(responseResultItems); i++ { + responseResultItem := responseResultItems[i] + if responseResultItem.Get("response").Exists() { + chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw) } } - chunk = []byte(chunkTemplate) } + chunk = []byte(chunkTemplate) } - if alt == "" { - _, _ = c.Writer.Write([]byte("data: ")) - _, _ = c.Writer.Write(chunk) - _, _ = c.Writer.Write([]byte("\n\n")) - } else { - _, _ = c.Writer.Write(chunk) - } - flusher.Flush() } + if alt == "" { + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + } else { + _, _ = c.Writer.Write(chunk) + } + flusher.Flush() // Handle errors from the backend. case err, okError := <-errChan: if okError { - if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { log.Debugf("quota exceeded, switch client") continue outLoop } else { @@ -258,12 +283,12 @@ outLoop: } } -func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) { +func (h *GeminiAPIHandlers) geminiCountTokens(c *gin.Context, rawJSON []byte) { c.Header("Content-Type", "application/json") - alt := h.getAlt(c) - // orgRawJson := rawJson - modelResult := gjson.GetBytes(rawJson, "model") + alt := h.GetAlt(c) + // orgrawJSON := rawJSON + modelResult := gjson.GetBytes(rawJSON, "model") modelName := modelResult.String() cliCtx, cliCancel := context.WithCancel(context.Background()) var cliClient *client.Client @@ -275,7 +300,7 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) { for { var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.getClient(modelName, false) + cliClient, errorResponse = h.GetClient(modelName, false) if errorResponse != nil { c.Status(errorResponse.StatusCode) _, _ = fmt.Fprint(c.Writer, errorResponse.Error) @@ -289,27 +314,27 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) { log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) template := `{"request":{}}` - if gjson.GetBytes(rawJson, "generateContentRequest").Exists() { - template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJson, "generateContentRequest").Raw) + if gjson.GetBytes(rawJSON, "generateContentRequest").Exists() { + template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJSON, "generateContentRequest").Raw) template, _ = sjson.Delete(template, "generateContentRequest") - } else if gjson.GetBytes(rawJson, "contents").Exists() { - template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJson, "contents").Raw) + } else if gjson.GetBytes(rawJSON, "contents").Exists() { + template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJSON, "contents").Raw) template, _ = sjson.Delete(template, "contents") } - rawJson = []byte(template) + rawJSON = []byte(template) } - resp, err := cliClient.SendRawTokenCount(cliCtx, rawJson, alt) + resp, err := cliClient.SendRawTokenCount(cliCtx, rawJSON, alt) if err != nil { - if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { continue } else { c.Status(err.StatusCode) _, _ = c.Writer.Write([]byte(err.Error.Error())) cliCancel() // log.Debugf(err.Error.Error()) - // log.Debugf(string(rawJson)) - // log.Debugf(string(orgRawJson)) + // log.Debugf(string(rawJSON)) + // log.Debugf(string(orgrawJSON)) } break } else { @@ -326,12 +351,12 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) { } } -func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) { +func (h *GeminiAPIHandlers) geminiGenerateContent(c *gin.Context, rawJSON []byte) { c.Header("Content-Type", "application/json") - alt := h.getAlt(c) + alt := h.GetAlt(c) - modelResult := gjson.GetBytes(rawJson, "model") + modelResult := gjson.GetBytes(rawJSON, "model") modelName := modelResult.String() cliCtx, cliCancel := context.WithCancel(context.Background()) var cliClient *client.Client @@ -343,7 +368,7 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) { for { var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.getClient(modelName) + cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) _, _ = fmt.Fprint(c.Writer, errorResponse.Error) @@ -352,21 +377,21 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) { } template := "" - parsed := gjson.Parse(string(rawJson)) + parsed := gjson.Parse(string(rawJSON)) contents := parsed.Get("request.contents") if contents.Exists() { - template = string(rawJson) + template = string(rawJSON) } else { template = `{"project":"","request":{},"model":""}` - template, _ = sjson.SetRaw(template, "request", string(rawJson)) + template, _ = sjson.SetRaw(template, "request", string(rawJSON)) template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) template, _ = sjson.Delete(template, "request.model") } - template, errFixCLIToolResponse := translator.FixCLIToolResponse(template) + template, errFixCLIToolResponse := cli.FixCLIToolResponse(template) if errFixCLIToolResponse != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: errFixCLIToolResponse.Error(), Type: "server_error", }, @@ -380,16 +405,16 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) { template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) template, _ = sjson.Delete(template, "request.system_instruction") } - rawJson = []byte(template) + rawJSON = []byte(template) if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { log.Debugf("Request use generative language API Key: %s", glAPIKey) } else { log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) } - resp, err := cliClient.SendRawMessage(cliCtx, rawJson, alt) + resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, alt) if err != nil { - if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { continue } else { c.Status(err.StatusCode) @@ -410,16 +435,3 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) { } } } - -func (h *APIHandlers) getAlt(c *gin.Context) string { - var alt string - var hasAlt bool - alt, hasAlt = c.GetQuery("alt") - if !hasAlt { - alt, _ = c.GetQuery("$alt") - } - if alt == "sse" { - return "" - } - return alt -} diff --git a/internal/api/handlers/handlers.go b/internal/api/handlers/handlers.go new file mode 100644 index 00000000..f6fe2326 --- /dev/null +++ b/internal/api/handlers/handlers.go @@ -0,0 +1,122 @@ +// Package handlers provides core API handler functionality for the CLI Proxy API server. +// It includes common types, client management, load balancing, and error handling +// shared across all API endpoint handlers (OpenAI, Claude, Gemini). +package handlers + +import ( + "fmt" + "github.com/gin-gonic/gin" + "github.com/luispater/CLIProxyAPI/internal/client" + "github.com/luispater/CLIProxyAPI/internal/config" + log "github.com/sirupsen/logrus" + "sync" +) + +// ErrorResponse represents a standard error response format for the API. +// It contains a single ErrorDetail field. +type ErrorResponse struct { + Error ErrorDetail `json:"error"` +} + +// ErrorDetail provides specific information about an error that occurred. +// It includes a human-readable message, an error type, and an optional error code. +type ErrorDetail struct { + // A human-readable message providing more details about the error. + Message string `json:"message"` + // The type of error that occurred (e.g., "invalid_request_error"). + Type string `json:"type"` + // A short code identifying the error, if applicable. + Code string `json:"code,omitempty"` +} + +// APIHandlers contains the handlers for API endpoints. +// It holds a pool of clients to interact with the backend service. +type APIHandlers struct { + CliClients []*client.Client + Cfg *config.Config + Mutex *sync.Mutex + LastUsedClientIndex int +} + +// NewAPIHandlers creates a new API handlers instance. +// It takes a slice of clients and a debug flag as input. +func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers { + return &APIHandlers{ + CliClients: cliClients, + Cfg: cfg, + Mutex: &sync.Mutex{}, + LastUsedClientIndex: 0, + } +} + +// UpdateClients updates the handlers' client list and configuration +func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config) { + h.CliClients = clients + h.Cfg = cfg +} + +// GetClient returns an available client from the pool using round-robin load balancing. +// It checks for quota limits and tries to find an unlocked client for immediate use. +// The modelName parameter is used to check quota status for specific models. +func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) { + if len(h.CliClients) == 0 { + return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")} + } + + var cliClient *client.Client + + // Lock the mutex to update the last used client index + h.Mutex.Lock() + startIndex := h.LastUsedClientIndex + if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 { + currentIndex := (startIndex + 1) % len(h.CliClients) + h.LastUsedClientIndex = currentIndex + } + h.Mutex.Unlock() + + // Reorder the client to start from the last used index + reorderedClients := make([]*client.Client, 0) + for i := 0; i < len(h.CliClients); i++ { + cliClient = h.CliClients[(startIndex+1+i)%len(h.CliClients)] + if cliClient.IsModelQuotaExceeded(modelName) { + log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) + cliClient = nil + continue + } + reorderedClients = append(reorderedClients, cliClient) + } + + if len(reorderedClients) == 0 { + return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)} + } + + locked := false + for i := 0; i < len(reorderedClients); i++ { + cliClient = reorderedClients[i] + if cliClient.RequestMutex.TryLock() { + locked = true + break + } + } + if !locked { + cliClient = h.CliClients[0] + cliClient.RequestMutex.Lock() + } + + return cliClient, nil +} + +// GetAlt extracts the 'alt' parameter from the request query string. +// It checks both 'alt' and '$alt' parameters and returns the appropriate value. +func (h *APIHandlers) GetAlt(c *gin.Context) string { + var alt string + var hasAlt bool + alt, hasAlt = c.GetQuery("alt") + if !hasAlt { + alt, _ = c.GetQuery("$alt") + } + if alt == "sse" { + return "" + } + return alt +} diff --git a/internal/api/handlers.go b/internal/api/handlers/openai/openai-handlers.go similarity index 59% rename from internal/api/handlers.go rename to internal/api/handlers/openai/openai-handlers.go index 23bf833f..7623278e 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers/openai/openai-handlers.go @@ -1,50 +1,42 @@ -package api +// Package openai provides HTTP handlers for OpenAI API endpoints. +// This package implements the OpenAI-compatible API interface, including model listing +// and chat completion functionality. It supports both streaming and non-streaming responses, +// and manages a pool of clients to interact with backend services. +// The handlers translate OpenAI API requests to the appropriate backend format and +// convert responses back to OpenAI-compatible format. +package openai import ( "context" "fmt" - "github.com/luispater/CLIProxyAPI/internal/api/translator" + "github.com/luispater/CLIProxyAPI/internal/api/handlers" + "github.com/luispater/CLIProxyAPI/internal/api/translator/openai" "github.com/luispater/CLIProxyAPI/internal/client" - "github.com/luispater/CLIProxyAPI/internal/config" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "net/http" - "sync" "time" "github.com/gin-gonic/gin" ) -var ( - mutex = &sync.Mutex{} - lastUsedClientIndex = 0 -) - -// APIHandlers contains the handlers for API endpoints. +// OpenAIAPIHandlers contains the handlers for OpenAI API endpoints. // It holds a pool of clients to interact with the backend service. -type APIHandlers struct { - cliClients []*client.Client - cfg *config.Config +type OpenAIAPIHandlers struct { + *handlers.APIHandlers } -// NewAPIHandlers creates a new API handlers instance. -// It takes a slice of clients and a debug flag as input. -func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers { - return &APIHandlers{ - cliClients: cliClients, - cfg: cfg, +// NewOpenAIAPIHandlers creates a new OpenAI API handlers instance. +// It takes an APIHandlers instance as input and returns an OpenAIAPIHandlers. +func NewOpenAIAPIHandlers(apiHandlers *handlers.APIHandlers) *OpenAIAPIHandlers { + return &OpenAIAPIHandlers{ + APIHandlers: apiHandlers, } } -// UpdateClients updates the handlers' client list and configuration -func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config) { - h.cliClients = clients - h.cfg = cfg -} - // Models handles the /v1/models endpoint. // It returns a hardcoded list of available AI models. -func (h *APIHandlers) Models(c *gin.Context) { +func (h *OpenAIAPIHandlers) Models(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "data": []map[string]any{ { @@ -91,63 +83,15 @@ func (h *APIHandlers) Models(c *gin.Context) { }) } -func (h *APIHandlers) getClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) { - if len(h.cliClients) == 0 { - return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")} - } - - var cliClient *client.Client - - // Lock the mutex to update the last used client index - mutex.Lock() - startIndex := lastUsedClientIndex - if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 { - currentIndex := (startIndex + 1) % len(h.cliClients) - lastUsedClientIndex = currentIndex - } - mutex.Unlock() - - // Reorder the client to start from the last used index - reorderedClients := make([]*client.Client, 0) - for i := 0; i < len(h.cliClients); i++ { - cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)] - if cliClient.IsModelQuotaExceeded(modelName) { - log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) - cliClient = nil - continue - } - reorderedClients = append(reorderedClients, cliClient) - } - - if len(reorderedClients) == 0 { - return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)} - } - - locked := false - for i := 0; i < len(reorderedClients); i++ { - cliClient = reorderedClients[i] - if cliClient.RequestMutex.TryLock() { - locked = true - break - } - } - if !locked { - cliClient = h.cliClients[0] - cliClient.RequestMutex.Lock() - } - - return cliClient, nil -} - // ChatCompletions handles the /v1/chat/completions endpoint. // It determines whether the request is for a streaming or non-streaming response // and calls the appropriate handler. -func (h *APIHandlers) ChatCompletions(c *gin.Context) { - rawJson, err := c.GetRawData() +func (h *OpenAIAPIHandlers) ChatCompletions(c *gin.Context) { + rawJSON, err := c.GetRawData() // If data retrieval fails, return a 400 Bad Request error. if err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: fmt.Sprintf("Invalid request: %v", err), Type: "invalid_request_error", }, @@ -156,21 +100,21 @@ func (h *APIHandlers) ChatCompletions(c *gin.Context) { } // Check if the client requested a streaming response. - streamResult := gjson.GetBytes(rawJson, "stream") + streamResult := gjson.GetBytes(rawJSON, "stream") if streamResult.Type == gjson.True { - h.handleStreamingResponse(c, rawJson) + h.handleStreamingResponse(c, rawJSON) } else { - h.handleNonStreamingResponse(c, rawJson) + h.handleNonStreamingResponse(c, rawJSON) } } // handleNonStreamingResponse handles non-streaming chat completion responses. // It selects a client from the pool, sends the request, and aggregates the response // before sending it back to the client. -func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) { +func (h *OpenAIAPIHandlers) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) { c.Header("Content-Type", "application/json") - modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson) + modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON) cliCtx, cliCancel := context.WithCancel(context.Background()) var cliClient *client.Client defer func() { @@ -181,7 +125,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) for { var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.getClient(modelName) + cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) _, _ = fmt.Fprint(c.Writer, errorResponse.Error) @@ -197,9 +141,9 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) } - resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, systemInstruction, contents, tools) + resp, err := cliClient.SendMessage(cliCtx, rawJSON, modelName, systemInstruction, contents, tools) if err != nil { - if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { continue } else { c.Status(err.StatusCode) @@ -208,7 +152,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) } break } else { - openAIFormat := translator.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey) + openAIFormat := openai.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey) if openAIFormat != "" { _, _ = c.Writer.Write([]byte(openAIFormat)) } @@ -219,7 +163,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) } // handleStreamingResponse handles streaming responses -func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) { +func (h *OpenAIAPIHandlers) handleStreamingResponse(c *gin.Context, rawJSON []byte) { c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") @@ -228,8 +172,8 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) { // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { - c.JSON(http.StatusInternalServerError, ErrorResponse{ - Error: ErrorDetail{ + c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ Message: "Streaming not supported", Type: "server_error", }, @@ -238,7 +182,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) { } // Prepare the request for the backend client. - modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson) + modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON) cliCtx, cliCancel := context.WithCancel(context.Background()) var cliClient *client.Client defer func() { @@ -251,7 +195,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) { outLoop: for { var errorResponse *client.ErrorMessage - cliClient, errorResponse = h.getClient(modelName) + cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) _, _ = fmt.Fprint(c.Writer, errorResponse.Error) @@ -268,7 +212,7 @@ outLoop: log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) } // Send the message and receive response chunks and errors via channels. - respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools) + respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools) hasFirstResponse := false for { select { @@ -287,19 +231,18 @@ outLoop: flusher.Flush() cliCancel() return - } else { - // Convert the chunk to OpenAI format and send it to the client. - hasFirstResponse = true - openAIFormat := translator.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey) - if openAIFormat != "" { - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) - flusher.Flush() - } + } + // Convert the chunk to OpenAI format and send it to the client. + hasFirstResponse = true + openAIFormat := openai.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey) + if openAIFormat != "" { + _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) + flusher.Flush() } // Handle errors from the backend. case err, okError := <-errChan: if okError { - if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject { continue outLoop } else { c.Status(err.StatusCode) diff --git a/internal/api/models.go b/internal/api/models.go deleted file mode 100644 index 71f2bb5a..00000000 --- a/internal/api/models.go +++ /dev/null @@ -1,18 +0,0 @@ -package api - -// ErrorResponse represents a standard error response format for the API. -// It contains a single ErrorDetail field. -type ErrorResponse struct { - Error ErrorDetail `json:"error"` -} - -// ErrorDetail provides specific information about an error that occurred. -// It includes a human-readable message, an error type, and an optional error code. -type ErrorDetail struct { - // A human-readable message providing more details about the error. - Message string `json:"message"` - // The type of error that occurred (e.g., "invalid_request_error"). - Type string `json:"type"` - // A short code identifying the error, if applicable. - Code string `json:"code,omitempty"` -} diff --git a/internal/api/server.go b/internal/api/server.go index b3b6f505..612151b8 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -1,3 +1,7 @@ +// Package api provides the HTTP API server implementation for the CLI Proxy API. +// It includes the main server struct, routing setup, middleware for CORS and authentication, +// and integration with various AI API handlers (OpenAI, Claude, Gemini). +// The server supports hot-reloading of clients and configuration. package api import ( @@ -5,6 +9,11 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/luispater/CLIProxyAPI/internal/api/handlers" + "github.com/luispater/CLIProxyAPI/internal/api/handlers/claude" + "github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini" + "github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini/cli" + "github.com/luispater/CLIProxyAPI/internal/api/handlers/openai" "github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/config" log "github.com/sirupsen/logrus" @@ -17,7 +26,7 @@ import ( type Server struct { engine *gin.Engine server *http.Server - handlers *APIHandlers + handlers *handlers.APIHandlers cfg *config.Config } @@ -29,9 +38,6 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server { gin.SetMode(gin.ReleaseMode) } - // Create handlers - handlers := NewAPIHandlers(cliClients, cfg) - // Create gin engine engine := gin.New() @@ -43,7 +49,7 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server { // Create server instance s := &Server{ engine: engine, - handlers: handlers, + handlers: handlers.NewAPIHandlers(cliClients, cfg), cfg: cfg, } @@ -62,22 +68,27 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server { // setupRoutes configures the API routes for the server. // It defines the endpoints and associates them with their respective handlers. func (s *Server) setupRoutes() { + openaiHandlers := openai.NewOpenAIAPIHandlers(s.handlers) + geminiHandlers := gemini.NewGeminiAPIHandlers(s.handlers) + geminiCLIHandlers := cli.NewGeminiCLIAPIHandlers(s.handlers) + claudeCodeHandlers := claude.NewClaudeCodeAPIHandlers(s.handlers) + // OpenAI compatible API routes v1 := s.engine.Group("/v1") v1.Use(AuthMiddleware(s.cfg)) { - v1.GET("/models", s.handlers.Models) - v1.POST("/chat/completions", s.handlers.ChatCompletions) - v1.POST("/messages", s.handlers.ClaudeMessages) + v1.GET("/models", openaiHandlers.Models) + v1.POST("/chat/completions", openaiHandlers.ChatCompletions) + v1.POST("/messages", claudeCodeHandlers.ClaudeMessages) } // Gemini compatible API routes v1beta := s.engine.Group("/v1beta") v1beta.Use(AuthMiddleware(s.cfg)) { - v1beta.GET("/models", s.handlers.GeminiModels) - v1beta.POST("/models/:action", s.handlers.GeminiHandler) - v1beta.GET("/models/:action", s.handlers.GeminiGetHandler) + v1beta.GET("/models", geminiHandlers.GeminiModels) + v1beta.POST("/models/:action", geminiHandlers.GeminiHandler) + v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler) } // Root endpoint @@ -91,7 +102,7 @@ func (s *Server) setupRoutes() { }, }) }) - s.engine.POST("/v1internal:method", s.handlers.CLIHandler) + s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler) } @@ -150,7 +161,7 @@ func (s *Server) UpdateClients(clients []*client.Client, cfg *config.Config) { // using API keys. If no API keys are configured, it allows all requests. func AuthMiddleware(cfg *config.Config) gin.HandlerFunc { return func(c *gin.Context) { - if len(cfg.ApiKeys) == 0 { + if len(cfg.APIKeys) == 0 { c.Next() return } @@ -181,9 +192,9 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc { // Find the API key in the in-memory list var foundKey string - for i := range cfg.ApiKeys { - if cfg.ApiKeys[i] == apiKey || cfg.ApiKeys[i] == authHeaderGoogle || cfg.ApiKeys[i] == authHeaderAnthropic || cfg.ApiKeys[i] == apiKeyQuery { - foundKey = cfg.ApiKeys[i] + for i := range cfg.APIKeys { + if cfg.APIKeys[i] == apiKey || cfg.APIKeys[i] == authHeaderGoogle || cfg.APIKeys[i] == authHeaderAnthropic || cfg.APIKeys[i] == apiKeyQuery { + foundKey = cfg.APIKeys[i] break } } diff --git a/internal/api/translator/claude/code/request.go b/internal/api/translator/claude/code/request.go new file mode 100644 index 00000000..4fe924af --- /dev/null +++ b/internal/api/translator/claude/code/request.go @@ -0,0 +1,169 @@ +// Package code provides request translation functionality for Claude API. +// It handles parsing and transforming Claude API requests into the internal client format, +// extracting model information, system instructions, message contents, and tool declarations. +// The package also performs JSON data cleaning and transformation to ensure compatibility +// between Claude API format and the internal client's expected format. +package code + +import ( + "bytes" + "encoding/json" + "github.com/luispater/CLIProxyAPI/internal/client" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "strings" +) + +// PrepareClaudeRequest parses and transforms a Claude API request into internal client format. +// It extracts the model name, system instruction, message contents, and tool declarations +// from the raw JSON request and returns them in the format expected by the internal client. +func PrepareClaudeRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) { + var pathsToDelete []string + root := gjson.ParseBytes(rawJSON) + walk(root, "", "additionalProperties", &pathsToDelete) + walk(root, "", "$schema", &pathsToDelete) + + var err error + for _, p := range pathsToDelete { + rawJSON, err = sjson.DeleteBytes(rawJSON, p) + if err != nil { + continue + } + } + rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1) + + // log.Debug(string(rawJSON)) + modelName := "gemini-2.5-pro" + modelResult := gjson.GetBytes(rawJSON, "model") + if modelResult.Type == gjson.String { + modelName = modelResult.String() + } + + contents := make([]client.Content, 0) + + var systemInstruction *client.Content + + systemResult := gjson.GetBytes(rawJSON, "system") + if systemResult.IsArray() { + systemResults := systemResult.Array() + systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}} + for i := 0; i < len(systemResults); i++ { + systemPromptResult := systemResults[i] + systemTypePromptResult := systemPromptResult.Get("type") + if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" { + systemPrompt := systemPromptResult.Get("text").String() + systemPart := client.Part{Text: systemPrompt} + systemInstruction.Parts = append(systemInstruction.Parts, systemPart) + } + } + if len(systemInstruction.Parts) == 0 { + systemInstruction = nil + } + } + + messagesResult := gjson.GetBytes(rawJSON, "messages") + if messagesResult.IsArray() { + messageResults := messagesResult.Array() + for i := 0; i < len(messageResults); i++ { + messageResult := messageResults[i] + roleResult := messageResult.Get("role") + if roleResult.Type != gjson.String { + continue + } + role := roleResult.String() + if role == "assistant" { + role = "model" + } + clientContent := client.Content{Role: role, Parts: []client.Part{}} + + contentsResult := messageResult.Get("content") + if contentsResult.IsArray() { + contentResults := contentsResult.Array() + for j := 0; j < len(contentResults); j++ { + contentResult := contentResults[j] + contentTypeResult := contentResult.Get("type") + if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { + prompt := contentResult.Get("text").String() + clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt}) + } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { + functionName := contentResult.Get("name").String() + functionArgs := contentResult.Get("input").String() + var args map[string]any + if err = json.Unmarshal([]byte(functionArgs), &args); err == nil { + clientContent.Parts = append(clientContent.Parts, client.Part{ + FunctionCall: &client.FunctionCall{ + Name: functionName, + Args: args, + }, + }) + } + } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" { + toolCallID := contentResult.Get("tool_use_id").String() + if toolCallID != "" { + funcName := toolCallID + toolCallIDs := strings.Split(toolCallID, "-") + if len(toolCallIDs) > 1 { + funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-") + } + responseData := contentResult.Get("content").String() + functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}} + clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse}) + } + } + } + contents = append(contents, clientContent) + } else if contentsResult.Type == gjson.String { + prompt := contentsResult.String() + contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}}) + } + } + } + + var tools []client.ToolDeclaration + toolsResult := gjson.GetBytes(rawJSON, "tools") + if toolsResult.IsArray() { + tools = make([]client.ToolDeclaration, 1) + tools[0].FunctionDeclarations = make([]any, 0) + toolsResults := toolsResult.Array() + for i := 0; i < len(toolsResults); i++ { + toolResult := toolsResults[i] + inputSchemaResult := toolResult.Get("input_schema") + if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { + inputSchema := inputSchemaResult.Raw + inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties") + inputSchema, _ = sjson.Delete(inputSchema, "$schema") + + tool, _ := sjson.Delete(toolResult.Raw, "input_schema") + tool, _ = sjson.SetRaw(tool, "parameters", inputSchema) + var toolDeclaration any + if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil { + tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration) + } + } + } + } else { + tools = make([]client.ToolDeclaration, 0) + } + + return modelName, systemInstruction, contents, tools +} + +func walk(value gjson.Result, path, field string, pathsToDelete *[]string) { + switch value.Type { + case gjson.JSON: + value.ForEach(func(key, val gjson.Result) bool { + var childPath string + if path == "" { + childPath = key.String() + } else { + childPath = path + "." + key.String() + } + if key.String() == field { + *pathsToDelete = append(*pathsToDelete, childPath) + } + walk(val, childPath, field, pathsToDelete) + return true + }) + case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null: + } +} diff --git a/internal/api/translator/claude/code/response.go b/internal/api/translator/claude/code/response.go new file mode 100644 index 00000000..3ef5fc2b --- /dev/null +++ b/internal/api/translator/claude/code/response.go @@ -0,0 +1,206 @@ +// Package code provides response translation functionality for Claude API. +// This package handles the conversion of backend client responses into Claude-compatible +// Server-Sent Events (SSE) format, implementing a sophisticated state machine that manages +// different response types including text content, thinking processes, and function calls. +// The translation ensures proper sequencing of SSE events and maintains state across +// multiple response chunks to provide a seamless streaming experience. +package code + +import ( + "bytes" + "fmt" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "time" +) + +// ConvertCliToClaude performs sophisticated streaming response format conversion. +// This function implements a complex state machine that translates backend client responses +// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types +// and handles state transitions between content blocks, thinking processes, and function calls. +// +// Response type states: 0=none, 1=content, 2=thinking, 3=function +// The function maintains state across multiple calls to ensure proper SSE event sequencing. +func ConvertCliToClaude(rawJSON []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string { + // Normalize the response format for different API key types + // Generative Language API keys have a different response structure + if isGlAPIKey { + rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON) + } + + // Track whether tools are being used in this response chunk + usedTool := false + output := "" + + // Initialize the streaming session with a message_start event + // This is only sent for the very first response chunk + if !hasFirstResponse { + output = "event: message_start\n" + + // Create the initial message structure with default values + // This follows the Claude API specification for streaming message initialization + messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}` + + // Override default values with actual response metadata if available + if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { + messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String()) + } + if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { + messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String()) + } + output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate) + } + + // Process the response parts array from the backend client + // Each part can contain text content, thinking content, or function calls + partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts") + if partsResult.IsArray() { + partResults := partsResult.Array() + for i := 0; i < len(partResults); i++ { + partResult := partResults[i] + + // Extract the different types of content from each part + partTextResult := partResult.Get("text") + functionCallResult := partResult.Get("functionCall") + + // Handle text content (both regular content and thinking) + if partTextResult.Exists() { + // Process thinking content (internal reasoning) + if partResult.Get("thought").Bool() { + // Continue existing thinking block + if *responseType == 2 { + output = output + "event: content_block_delta\n" + data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String()) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + } else { + // Transition from another state to thinking + // First, close any existing content block + if *responseType != 0 { + if *responseType == 2 { + output = output + "event: content_block_delta\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex) + output = output + "\n\n\n" + } + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) + output = output + "\n\n\n" + *responseIndex++ + } + + // Start a new thinking content block + output = output + "event: content_block_start\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex) + output = output + "\n\n\n" + output = output + "event: content_block_delta\n" + data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String()) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + *responseType = 2 // Set state to thinking + } + } else { + // Process regular text content (user-visible output) + // Continue existing text block + if *responseType == 1 { + output = output + "event: content_block_delta\n" + data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String()) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + } else { + // Transition from another state to text content + // First, close any existing content block + if *responseType != 0 { + if *responseType == 2 { + output = output + "event: content_block_delta\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex) + output = output + "\n\n\n" + } + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) + output = output + "\n\n\n" + *responseIndex++ + } + + // Start a new text content block + output = output + "event: content_block_start\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex) + output = output + "\n\n\n" + output = output + "event: content_block_delta\n" + data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String()) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + *responseType = 1 // Set state to content + } + } + } else if functionCallResult.Exists() { + // Handle function/tool calls from the AI model + // This processes tool usage requests and formats them for Claude API compatibility + usedTool = true + fcName := functionCallResult.Get("name").String() + + // Handle state transitions when switching to function calls + // Close any existing function call block first + if *responseType == 3 { + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) + output = output + "\n\n\n" + *responseIndex++ + *responseType = 0 + } + + // Special handling for thinking state transition + if *responseType == 2 { + output = output + "event: content_block_delta\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex) + output = output + "\n\n\n" + } + + // Close any other existing content block + if *responseType != 0 { + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) + output = output + "\n\n\n" + *responseIndex++ + } + + // Start a new tool use content block + // This creates the structure for a function call in Claude format + output = output + "event: content_block_start\n" + + // Create the tool use block with unique ID and function details + data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex) + data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + data, _ = sjson.Set(data, "content_block.name", fcName) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + + if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { + output = output + "event: content_block_delta\n" + data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + } + *responseType = 3 + } + } + } + + usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata") + if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) { + if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) + output = output + "\n\n\n" + + output = output + "event: message_delta\n" + output = output + `data: ` + + template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` + if usedTool { + template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` + } + + thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount) + template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int()) + + output = output + template + "\n\n\n" + } + } + + return output +} diff --git a/internal/api/translator/gemini/cli/request.go b/internal/api/translator/gemini/cli/request.go new file mode 100644 index 00000000..460820d0 --- /dev/null +++ b/internal/api/translator/gemini/cli/request.go @@ -0,0 +1,185 @@ +// Package cli provides request translation functionality for Gemini CLI API. +// It handles the conversion and formatting of CLI tool responses, specifically +// transforming between different JSON formats to ensure proper conversation flow +// and API compatibility. The package focuses on intelligently grouping function +// calls with their corresponding responses, converting from linear format to +// grouped format where function calls and responses are properly associated. +package cli + +import ( + "encoding/json" + "fmt" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// FunctionCallGroup represents a group of function calls and their responses +type FunctionCallGroup struct { + ModelContent map[string]interface{} + FunctionCalls []gjson.Result + ResponsesNeeded int +} + +// FixCLIToolResponse performs sophisticated tool response format conversion and grouping. +// This function transforms the CLI tool response format by intelligently grouping function calls +// with their corresponding responses, ensuring proper conversation flow and API compatibility. +// It converts from a linear format (1.json) to a grouped format (2.json) where function calls +// and their responses are properly associated and structured. +func FixCLIToolResponse(input string) (string, error) { + // Parse the input JSON to extract the conversation structure + parsed := gjson.Parse(input) + + // Extract the contents array which contains the conversation messages + contents := parsed.Get("request.contents") + if !contents.Exists() { + // log.Debugf(input) + return input, fmt.Errorf("contents not found in input") + } + + // Initialize data structures for processing and grouping + var newContents []interface{} // Final processed contents array + var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses + var collectedResponses []gjson.Result // Standalone responses to be matched + + // Process each content object in the conversation + // This iterates through messages and groups function calls with their responses + contents.ForEach(func(key, value gjson.Result) bool { + role := value.Get("role").String() + parts := value.Get("parts") + + // Check if this content has function responses + var responsePartsInThisContent []gjson.Result + parts.ForEach(func(_, part gjson.Result) bool { + if part.Get("functionResponse").Exists() { + responsePartsInThisContent = append(responsePartsInThisContent, part) + } + return true + }) + + // If this content has function responses, collect them + if len(responsePartsInThisContent) > 0 { + collectedResponses = append(collectedResponses, responsePartsInThisContent...) + + // Check if any pending groups can be satisfied + for i := len(pendingGroups) - 1; i >= 0; i-- { + group := pendingGroups[i] + if len(collectedResponses) >= group.ResponsesNeeded { + // Take the needed responses for this group + groupResponses := collectedResponses[:group.ResponsesNeeded] + collectedResponses = collectedResponses[group.ResponsesNeeded:] + + // Create merged function response content + var responseParts []interface{} + for _, response := range groupResponses { + var responseMap map[string]interface{} + errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap) + if errUnmarshal != nil { + log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal) + continue + } + responseParts = append(responseParts, responseMap) + } + + if len(responseParts) > 0 { + functionResponseContent := map[string]interface{}{ + "parts": responseParts, + "role": "function", + } + newContents = append(newContents, functionResponseContent) + } + + // Remove this group as it's been satisfied + pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...) + break + } + } + + return true // Skip adding this content, responses are merged + } + + // If this is a model with function calls, create a new group + if role == "model" { + var functionCallsInThisModel []gjson.Result + parts.ForEach(func(_, part gjson.Result) bool { + if part.Get("functionCall").Exists() { + functionCallsInThisModel = append(functionCallsInThisModel, part) + } + return true + }) + + if len(functionCallsInThisModel) > 0 { + // Add the model content + var contentMap map[string]interface{} + errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) + if errUnmarshal != nil { + log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal) + return true + } + newContents = append(newContents, contentMap) + + // Create a new group for tracking responses + group := &FunctionCallGroup{ + ModelContent: contentMap, + FunctionCalls: functionCallsInThisModel, + ResponsesNeeded: len(functionCallsInThisModel), + } + pendingGroups = append(pendingGroups, group) + } else { + // Regular model content without function calls + var contentMap map[string]interface{} + errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) + if errUnmarshal != nil { + log.Warnf("failed to unmarshal content: %v\n", errUnmarshal) + return true + } + newContents = append(newContents, contentMap) + } + } else { + // Non-model content (user, etc.) + var contentMap map[string]interface{} + errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) + if errUnmarshal != nil { + log.Warnf("failed to unmarshal content: %v\n", errUnmarshal) + return true + } + newContents = append(newContents, contentMap) + } + + return true + }) + + // Handle any remaining pending groups with remaining responses + for _, group := range pendingGroups { + if len(collectedResponses) >= group.ResponsesNeeded { + groupResponses := collectedResponses[:group.ResponsesNeeded] + collectedResponses = collectedResponses[group.ResponsesNeeded:] + + var responseParts []interface{} + for _, response := range groupResponses { + var responseMap map[string]interface{} + errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap) + if errUnmarshal != nil { + log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal) + continue + } + responseParts = append(responseParts, responseMap) + } + + if len(responseParts) > 0 { + functionResponseContent := map[string]interface{}{ + "parts": responseParts, + "role": "function", + } + newContents = append(newContents, functionResponseContent) + } + } + } + + // Update the original JSON with the new contents + result := input + newContentsJSON, _ := json.Marshal(newContents) + result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON)) + + return result, nil +} diff --git a/internal/api/translator/mime-type.go b/internal/api/translator/mime-type.go index 95938ff1..c467b183 100644 --- a/internal/api/translator/mime-type.go +++ b/internal/api/translator/mime-type.go @@ -1,3 +1,6 @@ +// Package translator provides data translation and format conversion utilities +// for the CLI Proxy API. It includes MIME type mappings and other translation +// functions used across different API endpoints. package translator // MimeTypes is a comprehensive map of file extensions to their corresponding MIME types. diff --git a/internal/api/translator/openai/request.go b/internal/api/translator/openai/request.go new file mode 100644 index 00000000..7251ca9e --- /dev/null +++ b/internal/api/translator/openai/request.go @@ -0,0 +1,226 @@ +// Package openai provides request translation functionality for OpenAI API. +// It handles the conversion of OpenAI-compatible request formats to the internal +// format expected by the backend client, including parsing messages, roles, +// content types (text, image, file), and tool calls. +package openai + +import ( + "encoding/json" + "github.com/luispater/CLIProxyAPI/internal/api/translator" + "strings" + + "github.com/luispater/CLIProxyAPI/internal/client" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" +) + +// PrepareRequest translates a raw JSON request from an OpenAI-compatible format +// to the internal format expected by the backend client. It parses messages, +// roles, content types (text, image, file), and tool calls. +func PrepareRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) { + // Extract the model name from the request, defaulting to "gemini-2.5-pro". + modelName := "gemini-2.5-pro" + modelResult := gjson.GetBytes(rawJSON, "model") + if modelResult.Type == gjson.String { + modelName = modelResult.String() + } + + // Initialize data structures for processing conversation messages + // contents: stores the processed conversation history + // systemInstruction: stores system-level instructions separate from conversation + contents := make([]client.Content, 0) + var systemInstruction *client.Content + messagesResult := gjson.GetBytes(rawJSON, "messages") + + // Pre-process tool responses to create a lookup map + // This first pass collects all tool responses so they can be matched with their corresponding calls + toolItems := make(map[string]*client.FunctionResponse) + if messagesResult.IsArray() { + messagesResults := messagesResult.Array() + for i := 0; i < len(messagesResults); i++ { + messageResult := messagesResults[i] + roleResult := messageResult.Get("role") + if roleResult.Type != gjson.String { + continue + } + contentResult := messageResult.Get("content") + + // Extract tool responses for later matching with function calls + if roleResult.String() == "tool" { + toolCallID := messageResult.Get("tool_call_id").String() + if toolCallID != "" { + var responseData string + // Handle both string and object-based tool response formats + if contentResult.Type == gjson.String { + responseData = contentResult.String() + } else if contentResult.IsObject() && contentResult.Get("type").String() == "text" { + responseData = contentResult.Get("text").String() + } + + // Clean up tool call ID by removing timestamp suffix + // This normalizes IDs for consistent matching between calls and responses + toolCallIDs := strings.Split(toolCallID, "-") + strings.Join(toolCallIDs, "-") + newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-") + + // Create function response object with normalized ID and response data + functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}} + toolItems[toolCallID] = &functionResponse + } + } + } + } + + if messagesResult.IsArray() { + messagesResults := messagesResult.Array() + for i := 0; i < len(messagesResults); i++ { + messageResult := messagesResults[i] + roleResult := messageResult.Get("role") + contentResult := messageResult.Get("content") + if roleResult.Type != gjson.String { + continue + } + + switch roleResult.String() { + // System messages are converted to a user message followed by a model's acknowledgment. + case "system": + if contentResult.Type == gjson.String { + systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}} + } else if contentResult.IsObject() { + // Handle object-based system messages. + if contentResult.Get("type").String() == "text" { + systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}} + } + } + // User messages can contain simple text or a multi-part body. + case "user": + if contentResult.Type == gjson.String { + contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}) + } else if contentResult.IsArray() { + // Handle multi-part user messages (text, images, files). + contentItemResults := contentResult.Array() + parts := make([]client.Part, 0) + for j := 0; j < len(contentItemResults); j++ { + contentItemResult := contentItemResults[j] + contentTypeResult := contentItemResult.Get("type") + switch contentTypeResult.String() { + case "text": + parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()}) + case "image_url": + // Parse data URI for images. + imageURL := contentItemResult.Get("image_url.url").String() + if len(imageURL) > 5 { + imageURLs := strings.SplitN(imageURL[5:], ";", 2) + if len(imageURLs) == 2 && len(imageURLs[1]) > 7 { + parts = append(parts, client.Part{InlineData: &client.InlineData{ + MimeType: imageURLs[0], + Data: imageURLs[1][7:], + }}) + } + } + case "file": + // Handle file attachments by determining MIME type from extension. + filename := contentItemResult.Get("file.filename").String() + fileData := contentItemResult.Get("file.file_data").String() + ext := "" + if split := strings.Split(filename, "."); len(split) > 1 { + ext = split[len(split)-1] + } + if mimeType, ok := translator.MimeTypes[ext]; ok { + parts = append(parts, client.Part{InlineData: &client.InlineData{ + MimeType: mimeType, + Data: fileData, + }}) + } else { + log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j) + } + } + } + contents = append(contents, client.Content{Role: "user", Parts: parts}) + } + // Assistant messages can contain text responses or tool calls + // In the internal format, assistant messages are converted to "model" role + case "assistant": + if contentResult.Type == gjson.String { + // Simple text response from the assistant + contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}}) + } else if !contentResult.Exists() || contentResult.Type == gjson.Null { + // Handle complex tool calls made by the assistant + // This processes function calls and matches them with their responses + functionIDs := make([]string, 0) + toolCallsResult := messageResult.Get("tool_calls") + if toolCallsResult.IsArray() { + parts := make([]client.Part, 0) + tcsResult := toolCallsResult.Array() + + // Process each tool call in the assistant's message + for j := 0; j < len(tcsResult); j++ { + tcResult := tcsResult[j] + + // Extract function call details + functionID := tcResult.Get("id").String() + functionIDs = append(functionIDs, functionID) + + functionName := tcResult.Get("function.name").String() + functionArgs := tcResult.Get("function.arguments").String() + + // Parse function arguments from JSON string to map + var args map[string]any + if err := json.Unmarshal([]byte(functionArgs), &args); err == nil { + parts = append(parts, client.Part{ + FunctionCall: &client.FunctionCall{ + Name: functionName, + Args: args, + }, + }) + } + } + + // Add the model's function calls to the conversation + if len(parts) > 0 { + contents = append(contents, client.Content{ + Role: "model", Parts: parts, + }) + + // Create a separate tool response message with the collected responses + // This matches function calls with their corresponding responses + toolParts := make([]client.Part, 0) + for _, functionID := range functionIDs { + if functionResponse, ok := toolItems[functionID]; ok { + toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse}) + } + } + // Add the tool responses as a separate message in the conversation + contents = append(contents, client.Content{Role: "tool", Parts: toolParts}) + } + } + } + } + } + } + + // Translate the tool declarations from the request. + var tools []client.ToolDeclaration + toolsResult := gjson.GetBytes(rawJSON, "tools") + if toolsResult.IsArray() { + tools = make([]client.ToolDeclaration, 1) + tools[0].FunctionDeclarations = make([]any, 0) + toolsResults := toolsResult.Array() + for i := 0; i < len(toolsResults); i++ { + toolResult := toolsResults[i] + if toolResult.Get("type").String() == "function" { + functionTypeResult := toolResult.Get("function") + if functionTypeResult.Exists() && functionTypeResult.IsObject() { + var functionDeclaration any + if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil { + tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration) + } + } + } + } + } else { + tools = make([]client.ToolDeclaration, 0) + } + + return modelName, systemInstruction, contents, tools +} diff --git a/internal/api/translator/openai/response.go b/internal/api/translator/openai/response.go new file mode 100644 index 00000000..67757e29 --- /dev/null +++ b/internal/api/translator/openai/response.go @@ -0,0 +1,197 @@ +// Package openai provides response translation functionality for converting between +// different API response formats and OpenAI-compatible formats. It handles both +// streaming and non-streaming responses, transforming backend client responses +// into OpenAI Server-Sent Events (SSE) format and standard JSON response formats. +// The package supports content translation, function calls, usage metadata, +// and various response attributes while maintaining compatibility with OpenAI API +// specifications. +package openai + +import ( + "fmt" + "time" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ConvertCliToOpenAI translates a single chunk of a streaming response from the +// backend client format to the OpenAI Server-Sent Events (SSE) format. +// It returns an empty string if the chunk contains no useful data. +func ConvertCliToOpenAI(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string { + if isGlAPIKey { + rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON) + } + + // Initialize the OpenAI SSE template. + template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` + + // Extract and set the model version. + if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { + template, _ = sjson.Set(template, "model", modelVersionResult.String()) + } + + // Extract and set the creation timestamp. + if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() { + t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) + if err == nil { + unixTimestamp = t.Unix() + } + template, _ = sjson.Set(template, "created", unixTimestamp) + } else { + template, _ = sjson.Set(template, "created", unixTimestamp) + } + + // Extract and set the response ID. + if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { + template, _ = sjson.Set(template, "id", responseIDResult.String()) + } + + // Extract and set the finish reason. + if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() { + template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String()) + template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String()) + } + + // Extract and set usage metadata (token counts). + if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() { + if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { + template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) + } + if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { + template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) + } + promptTokenCount := usageResult.Get("promptTokenCount").Int() + thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + if thoughtsTokenCount > 0 { + template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) + } + } + + // Process the main content part of the response. + partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts") + if partsResult.IsArray() { + partResults := partsResult.Array() + for i := 0; i < len(partResults); i++ { + partResult := partResults[i] + partTextResult := partResult.Get("text") + functionCallResult := partResult.Get("functionCall") + + if partTextResult.Exists() { + // Handle text content, distinguishing between regular content and reasoning/thoughts. + if partResult.Get("thought").Bool() { + template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String()) + } else { + template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String()) + } + template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") + } else if functionCallResult.Exists() { + // Handle function call content. + toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls") + if !toolCallsResult.Exists() || !toolCallsResult.IsArray() { + template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`) + } + + functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` + fcName := functionCallResult.Get("name").String() + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName) + if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw) + } + template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") + template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate) + } + } + } + + return template +} + +// ConvertCliToOpenAINonStream aggregates response from the backend client +// convert a single, non-streaming OpenAI-compatible JSON response. +func ConvertCliToOpenAINonStream(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string { + if isGlAPIKey { + rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON) + } + template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` + if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() { + template, _ = sjson.Set(template, "model", modelVersionResult.String()) + } + + if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() { + t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) + if err == nil { + unixTimestamp = t.Unix() + } + template, _ = sjson.Set(template, "created", unixTimestamp) + } else { + template, _ = sjson.Set(template, "created", unixTimestamp) + } + + if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() { + template, _ = sjson.Set(template, "id", responseIDResult.String()) + } + + if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() { + template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String()) + template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String()) + } + + if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() { + if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { + template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) + } + if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { + template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) + } + promptTokenCount := usageResult.Get("promptTokenCount").Int() + thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + if thoughtsTokenCount > 0 { + template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) + } + } + + // Process the main content part of the response. + partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts") + if partsResult.IsArray() { + partsResults := partsResult.Array() + for i := 0; i < len(partsResults); i++ { + partResult := partsResults[i] + partTextResult := partResult.Get("text") + functionCallResult := partResult.Get("functionCall") + + if partTextResult.Exists() { + // Append text content, distinguishing between regular content and reasoning. + if partResult.Get("thought").Bool() { + template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String()) + } else { + template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String()) + } + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + } else if functionCallResult.Exists() { + // Append function call content to the tool_calls array. + toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls") + if !toolCallsResult.Exists() || !toolCallsResult.IsArray() { + template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`) + } + functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` + fcName := functionCallResult.Get("name").String() + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName) + if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw) + } + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate) + } else { + // If no usable content is found, return an empty string. + return "" + } + } + } + + return template +} diff --git a/internal/api/translator/request.go b/internal/api/translator/request.go deleted file mode 100644 index bd6720ba..00000000 --- a/internal/api/translator/request.go +++ /dev/null @@ -1,545 +0,0 @@ -package translator - -import ( - "bytes" - "encoding/json" - "fmt" - "github.com/tidwall/sjson" - "strings" - - "github.com/luispater/CLIProxyAPI/internal/client" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" -) - -// PrepareRequest translates a raw JSON request from an OpenAI-compatible format -// to the internal format expected by the backend client. It parses messages, -// roles, content types (text, image, file), and tool calls. -func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) { - // Extract the model name from the request, defaulting to "gemini-2.5-pro". - modelName := "gemini-2.5-pro" - modelResult := gjson.GetBytes(rawJson, "model") - if modelResult.Type == gjson.String { - modelName = modelResult.String() - } - - // Initialize data structures for processing conversation messages - // contents: stores the processed conversation history - // systemInstruction: stores system-level instructions separate from conversation - contents := make([]client.Content, 0) - var systemInstruction *client.Content - messagesResult := gjson.GetBytes(rawJson, "messages") - - // Pre-process tool responses to create a lookup map - // This first pass collects all tool responses so they can be matched with their corresponding calls - toolItems := make(map[string]*client.FunctionResponse) - if messagesResult.IsArray() { - messagesResults := messagesResult.Array() - for i := 0; i < len(messagesResults); i++ { - messageResult := messagesResults[i] - roleResult := messageResult.Get("role") - if roleResult.Type != gjson.String { - continue - } - contentResult := messageResult.Get("content") - - // Extract tool responses for later matching with function calls - if roleResult.String() == "tool" { - toolCallID := messageResult.Get("tool_call_id").String() - if toolCallID != "" { - var responseData string - // Handle both string and object-based tool response formats - if contentResult.Type == gjson.String { - responseData = contentResult.String() - } else if contentResult.IsObject() && contentResult.Get("type").String() == "text" { - responseData = contentResult.Get("text").String() - } - - // Clean up tool call ID by removing timestamp suffix - // This normalizes IDs for consistent matching between calls and responses - toolCallIDs := strings.Split(toolCallID, "-") - strings.Join(toolCallIDs, "-") - newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-") - - // Create function response object with normalized ID and response data - functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}} - toolItems[toolCallID] = &functionResponse - } - } - } - } - - if messagesResult.IsArray() { - messagesResults := messagesResult.Array() - for i := 0; i < len(messagesResults); i++ { - messageResult := messagesResults[i] - roleResult := messageResult.Get("role") - contentResult := messageResult.Get("content") - if roleResult.Type != gjson.String { - continue - } - - switch roleResult.String() { - // System messages are converted to a user message followed by a model's acknowledgment. - case "system": - if contentResult.Type == gjson.String { - systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}} - } else if contentResult.IsObject() { - // Handle object-based system messages. - if contentResult.Get("type").String() == "text" { - systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}} - } - } - // User messages can contain simple text or a multi-part body. - case "user": - if contentResult.Type == gjson.String { - contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}) - } else if contentResult.IsArray() { - // Handle multi-part user messages (text, images, files). - contentItemResults := contentResult.Array() - parts := make([]client.Part, 0) - for j := 0; j < len(contentItemResults); j++ { - contentItemResult := contentItemResults[j] - contentTypeResult := contentItemResult.Get("type") - switch contentTypeResult.String() { - case "text": - parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()}) - case "image_url": - // Parse data URI for images. - imageURL := contentItemResult.Get("image_url.url").String() - if len(imageURL) > 5 { - imageURLs := strings.SplitN(imageURL[5:], ";", 2) - if len(imageURLs) == 2 && len(imageURLs[1]) > 7 { - parts = append(parts, client.Part{InlineData: &client.InlineData{ - MimeType: imageURLs[0], - Data: imageURLs[1][7:], - }}) - } - } - case "file": - // Handle file attachments by determining MIME type from extension. - filename := contentItemResult.Get("file.filename").String() - fileData := contentItemResult.Get("file.file_data").String() - ext := "" - if split := strings.Split(filename, "."); len(split) > 1 { - ext = split[len(split)-1] - } - if mimeType, ok := MimeTypes[ext]; ok { - parts = append(parts, client.Part{InlineData: &client.InlineData{ - MimeType: mimeType, - Data: fileData, - }}) - } else { - log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j) - } - } - } - contents = append(contents, client.Content{Role: "user", Parts: parts}) - } - // Assistant messages can contain text responses or tool calls - // In the internal format, assistant messages are converted to "model" role - case "assistant": - if contentResult.Type == gjson.String { - // Simple text response from the assistant - contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}}) - } else if !contentResult.Exists() || contentResult.Type == gjson.Null { - // Handle complex tool calls made by the assistant - // This processes function calls and matches them with their responses - functionIDs := make([]string, 0) - toolCallsResult := messageResult.Get("tool_calls") - if toolCallsResult.IsArray() { - parts := make([]client.Part, 0) - tcsResult := toolCallsResult.Array() - - // Process each tool call in the assistant's message - for j := 0; j < len(tcsResult); j++ { - tcResult := tcsResult[j] - - // Extract function call details - functionID := tcResult.Get("id").String() - functionIDs = append(functionIDs, functionID) - - functionName := tcResult.Get("function.name").String() - functionArgs := tcResult.Get("function.arguments").String() - - // Parse function arguments from JSON string to map - var args map[string]any - if err := json.Unmarshal([]byte(functionArgs), &args); err == nil { - parts = append(parts, client.Part{ - FunctionCall: &client.FunctionCall{ - Name: functionName, - Args: args, - }, - }) - } - } - - // Add the model's function calls to the conversation - if len(parts) > 0 { - contents = append(contents, client.Content{ - Role: "model", Parts: parts, - }) - - // Create a separate tool response message with the collected responses - // This matches function calls with their corresponding responses - toolParts := make([]client.Part, 0) - for _, functionID := range functionIDs { - if functionResponse, ok := toolItems[functionID]; ok { - toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse}) - } - } - // Add the tool responses as a separate message in the conversation - contents = append(contents, client.Content{Role: "tool", Parts: toolParts}) - } - } - } - } - } - } - - // Translate the tool declarations from the request. - var tools []client.ToolDeclaration - toolsResult := gjson.GetBytes(rawJson, "tools") - if toolsResult.IsArray() { - tools = make([]client.ToolDeclaration, 1) - tools[0].FunctionDeclarations = make([]any, 0) - toolsResults := toolsResult.Array() - for i := 0; i < len(toolsResults); i++ { - toolResult := toolsResults[i] - if toolResult.Get("type").String() == "function" { - functionTypeResult := toolResult.Get("function") - if functionTypeResult.Exists() && functionTypeResult.IsObject() { - var functionDeclaration any - if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil { - tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration) - } - } - } - } - } else { - tools = make([]client.ToolDeclaration, 0) - } - - return modelName, systemInstruction, contents, tools -} - -// FunctionCallGroup represents a group of function calls and their responses -type FunctionCallGroup struct { - ModelContent map[string]interface{} - FunctionCalls []gjson.Result - ResponsesNeeded int -} - -// FixCLIToolResponse performs sophisticated tool response format conversion and grouping. -// This function transforms the CLI tool response format by intelligently grouping function calls -// with their corresponding responses, ensuring proper conversation flow and API compatibility. -// It converts from a linear format (1.json) to a grouped format (2.json) where function calls -// and their responses are properly associated and structured. -func FixCLIToolResponse(input string) (string, error) { - // Parse the input JSON to extract the conversation structure - parsed := gjson.Parse(input) - - // Extract the contents array which contains the conversation messages - contents := parsed.Get("request.contents") - if !contents.Exists() { - // log.Debugf(input) - return input, fmt.Errorf("contents not found in input") - } - - // Initialize data structures for processing and grouping - var newContents []interface{} // Final processed contents array - var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses - var collectedResponses []gjson.Result // Standalone responses to be matched - - // Process each content object in the conversation - // This iterates through messages and groups function calls with their responses - contents.ForEach(func(key, value gjson.Result) bool { - role := value.Get("role").String() - parts := value.Get("parts") - - // Check if this content has function responses - var responsePartsInThisContent []gjson.Result - parts.ForEach(func(_, part gjson.Result) bool { - if part.Get("functionResponse").Exists() { - responsePartsInThisContent = append(responsePartsInThisContent, part) - } - return true - }) - - // If this content has function responses, collect them - if len(responsePartsInThisContent) > 0 { - collectedResponses = append(collectedResponses, responsePartsInThisContent...) - - // Check if any pending groups can be satisfied - for i := len(pendingGroups) - 1; i >= 0; i-- { - group := pendingGroups[i] - if len(collectedResponses) >= group.ResponsesNeeded { - // Take the needed responses for this group - groupResponses := collectedResponses[:group.ResponsesNeeded] - collectedResponses = collectedResponses[group.ResponsesNeeded:] - - // Create merged function response content - var responseParts []interface{} - for _, response := range groupResponses { - var responseMap map[string]interface{} - errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap) - if errUnmarshal != nil { - log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal) - continue - } - responseParts = append(responseParts, responseMap) - } - - if len(responseParts) > 0 { - functionResponseContent := map[string]interface{}{ - "parts": responseParts, - "role": "function", - } - newContents = append(newContents, functionResponseContent) - } - - // Remove this group as it's been satisfied - pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...) - break - } - } - - return true // Skip adding this content, responses are merged - } - - // If this is a model with function calls, create a new group - if role == "model" { - var functionCallsInThisModel []gjson.Result - parts.ForEach(func(_, part gjson.Result) bool { - if part.Get("functionCall").Exists() { - functionCallsInThisModel = append(functionCallsInThisModel, part) - } - return true - }) - - if len(functionCallsInThisModel) > 0 { - // Add the model content - var contentMap map[string]interface{} - errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) - if errUnmarshal != nil { - log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal) - return true - } - newContents = append(newContents, contentMap) - - // Create a new group for tracking responses - group := &FunctionCallGroup{ - ModelContent: contentMap, - FunctionCalls: functionCallsInThisModel, - ResponsesNeeded: len(functionCallsInThisModel), - } - pendingGroups = append(pendingGroups, group) - } else { - // Regular model content without function calls - var contentMap map[string]interface{} - errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) - if errUnmarshal != nil { - log.Warnf("failed to unmarshal content: %v\n", errUnmarshal) - return true - } - newContents = append(newContents, contentMap) - } - } else { - // Non-model content (user, etc.) - var contentMap map[string]interface{} - errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) - if errUnmarshal != nil { - log.Warnf("failed to unmarshal content: %v\n", errUnmarshal) - return true - } - newContents = append(newContents, contentMap) - } - - return true - }) - - // Handle any remaining pending groups with remaining responses - for _, group := range pendingGroups { - if len(collectedResponses) >= group.ResponsesNeeded { - groupResponses := collectedResponses[:group.ResponsesNeeded] - collectedResponses = collectedResponses[group.ResponsesNeeded:] - - var responseParts []interface{} - for _, response := range groupResponses { - var responseMap map[string]interface{} - errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap) - if errUnmarshal != nil { - log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal) - continue - } - responseParts = append(responseParts, responseMap) - } - - if len(responseParts) > 0 { - functionResponseContent := map[string]interface{}{ - "parts": responseParts, - "role": "function", - } - newContents = append(newContents, functionResponseContent) - } - } - } - - // Update the original JSON with the new contents - result := input - newContentsJSON, _ := json.Marshal(newContents) - result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON)) - - return result, nil -} - -func PrepareClaudeRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) { - var pathsToDelete []string - root := gjson.ParseBytes(rawJson) - walk(root, "", "additionalProperties", &pathsToDelete) - walk(root, "", "$schema", &pathsToDelete) - - var err error - for _, p := range pathsToDelete { - rawJson, err = sjson.DeleteBytes(rawJson, p) - if err != nil { - continue - } - } - rawJson = bytes.Replace(rawJson, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1) - - // log.Debug(string(rawJson)) - modelName := "gemini-2.5-pro" - modelResult := gjson.GetBytes(rawJson, "model") - if modelResult.Type == gjson.String { - modelName = modelResult.String() - } - - contents := make([]client.Content, 0) - - var systemInstruction *client.Content - - systemResult := gjson.GetBytes(rawJson, "system") - if systemResult.IsArray() { - systemResults := systemResult.Array() - systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}} - for i := 0; i < len(systemResults); i++ { - systemPromptResult := systemResults[i] - systemTypePromptResult := systemPromptResult.Get("type") - if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" { - systemPrompt := systemPromptResult.Get("text").String() - systemPart := client.Part{Text: systemPrompt} - systemInstruction.Parts = append(systemInstruction.Parts, systemPart) - } - } - if len(systemInstruction.Parts) == 0 { - systemInstruction = nil - } - } - - messagesResult := gjson.GetBytes(rawJson, "messages") - if messagesResult.IsArray() { - messageResults := messagesResult.Array() - for i := 0; i < len(messageResults); i++ { - messageResult := messageResults[i] - roleResult := messageResult.Get("role") - if roleResult.Type != gjson.String { - continue - } - role := roleResult.String() - if role == "assistant" { - role = "model" - } - clientContent := client.Content{Role: role, Parts: []client.Part{}} - - contentsResult := messageResult.Get("content") - if contentsResult.IsArray() { - contentResults := contentsResult.Array() - for j := 0; j < len(contentResults); j++ { - contentResult := contentResults[j] - contentTypeResult := contentResult.Get("type") - if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { - prompt := contentResult.Get("text").String() - clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt}) - } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { - functionName := contentResult.Get("name").String() - functionArgs := contentResult.Get("input").String() - var args map[string]any - if err = json.Unmarshal([]byte(functionArgs), &args); err == nil { - clientContent.Parts = append(clientContent.Parts, client.Part{ - FunctionCall: &client.FunctionCall{ - Name: functionName, - Args: args, - }, - }) - } - } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" { - toolCallID := contentResult.Get("tool_use_id").String() - if toolCallID != "" { - funcName := toolCallID - toolCallIDs := strings.Split(toolCallID, "-") - if len(toolCallIDs) > 1 { - funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-") - } - responseData := contentResult.Get("content").String() - functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}} - clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse}) - } - } - } - contents = append(contents, clientContent) - } else if contentsResult.Type == gjson.String { - prompt := contentsResult.String() - contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}}) - } - } - } - - var tools []client.ToolDeclaration - toolsResult := gjson.GetBytes(rawJson, "tools") - if toolsResult.IsArray() { - tools = make([]client.ToolDeclaration, 1) - tools[0].FunctionDeclarations = make([]any, 0) - toolsResults := toolsResult.Array() - for i := 0; i < len(toolsResults); i++ { - toolResult := toolsResults[i] - inputSchemaResult := toolResult.Get("input_schema") - if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { - inputSchema := inputSchemaResult.Raw - inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties") - inputSchema, _ = sjson.Delete(inputSchema, "$schema") - - tool, _ := sjson.Delete(toolResult.Raw, "input_schema") - tool, _ = sjson.SetRaw(tool, "parameters", inputSchema) - var toolDeclaration any - if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil { - tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration) - } - } - } - } else { - tools = make([]client.ToolDeclaration, 0) - } - - return modelName, systemInstruction, contents, tools -} - -func walk(value gjson.Result, path, field string, pathsToDelete *[]string) { - switch value.Type { - case gjson.JSON: - value.ForEach(func(key, val gjson.Result) bool { - var childPath string - if path == "" { - childPath = key.String() - } else { - childPath = path + "." + key.String() - } - if key.String() == field { - *pathsToDelete = append(*pathsToDelete, childPath) - } - walk(val, childPath, field, pathsToDelete) - return true - }) - case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null: - } -} diff --git a/internal/api/translator/response.go b/internal/api/translator/response.go deleted file mode 100644 index 0cc45da6..00000000 --- a/internal/api/translator/response.go +++ /dev/null @@ -1,382 +0,0 @@ -package translator - -import ( - "bytes" - "fmt" - "time" - - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" -) - -// ConvertCliToOpenAI translates a single chunk of a streaming response from the -// backend client format to the OpenAI Server-Sent Events (SSE) format. -// It returns an empty string if the chunk contains no useful data. -func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string { - if isGlAPIKey { - rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson) - } - - // Initialize the OpenAI SSE template. - template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` - - // Extract and set the model version. - if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() { - template, _ = sjson.Set(template, "model", modelVersionResult.String()) - } - - // Extract and set the creation timestamp. - if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() { - t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) - if err == nil { - unixTimestamp = t.Unix() - } - template, _ = sjson.Set(template, "created", unixTimestamp) - } else { - template, _ = sjson.Set(template, "created", unixTimestamp) - } - - // Extract and set the response ID. - if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() { - template, _ = sjson.Set(template, "id", responseIdResult.String()) - } - - // Extract and set the finish reason. - if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() { - template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String()) - template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String()) - } - - // Extract and set usage metadata (token counts). - if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() { - if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { - template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) - } - if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { - template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) - } - promptTokenCount := usageResult.Get("promptTokenCount").Int() - thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) - if thoughtsTokenCount > 0 { - template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) - } - } - - // Process the main content part of the response. - partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts") - if partsResult.IsArray() { - partResults := partsResult.Array() - for i := 0; i < len(partResults); i++ { - partResult := partResults[i] - partTextResult := partResult.Get("text") - functionCallResult := partResult.Get("functionCall") - - if partTextResult.Exists() { - // Handle text content, distinguishing between regular content and reasoning/thoughts. - if partResult.Get("thought").Bool() { - template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String()) - } else { - template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String()) - } - template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") - } else if functionCallResult.Exists() { - // Handle function call content. - toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls") - if !toolCallsResult.Exists() || !toolCallsResult.IsArray() { - template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`) - } - - functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` - fcName := functionCallResult.Get("name").String() - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName) - if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw) - } - template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") - template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate) - } - } - } - - return template -} - -// ConvertCliToOpenAINonStream aggregates response from the backend client -// convert a single, non-streaming OpenAI-compatible JSON response. -func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string { - if isGlAPIKey { - rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson) - } - template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` - if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() { - template, _ = sjson.Set(template, "model", modelVersionResult.String()) - } - - if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() { - t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) - if err == nil { - unixTimestamp = t.Unix() - } - template, _ = sjson.Set(template, "created", unixTimestamp) - } else { - template, _ = sjson.Set(template, "created", unixTimestamp) - } - - if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() { - template, _ = sjson.Set(template, "id", responseIdResult.String()) - } - - if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() { - template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String()) - template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String()) - } - - if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() { - if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { - template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) - } - if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { - template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) - } - promptTokenCount := usageResult.Get("promptTokenCount").Int() - thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) - if thoughtsTokenCount > 0 { - template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) - } - } - - // Process the main content part of the response. - partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts") - if partsResult.IsArray() { - partsResults := partsResult.Array() - for i := 0; i < len(partsResults); i++ { - partResult := partsResults[i] - partTextResult := partResult.Get("text") - functionCallResult := partResult.Get("functionCall") - - if partTextResult.Exists() { - // Append text content, distinguishing between regular content and reasoning. - if partResult.Get("thought").Bool() { - template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String()) - } else { - template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String()) - } - template, _ = sjson.Set(template, "choices.0.message.role", "assistant") - } else if functionCallResult.Exists() { - // Append function call content to the tool_calls array. - toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls") - if !toolCallsResult.Exists() || !toolCallsResult.IsArray() { - template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`) - } - functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` - fcName := functionCallResult.Get("name").String() - functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) - functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName) - if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { - functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw) - } - template, _ = sjson.Set(template, "choices.0.message.role", "assistant") - template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate) - } else { - // If no usable content is found, return an empty string. - return "" - } - } - } - - return template -} - -// ConvertCliToClaude performs sophisticated streaming response format conversion. -// This function implements a complex state machine that translates backend client responses -// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types -// and handles state transitions between content blocks, thinking processes, and function calls. -// -// Response type states: 0=none, 1=content, 2=thinking, 3=function -// The function maintains state across multiple calls to ensure proper SSE event sequencing. -func ConvertCliToClaude(rawJson []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string { - // Normalize the response format for different API key types - // Generative Language API keys have a different response structure - if isGlAPIKey { - rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson) - } - - // Track whether tools are being used in this response chunk - usedTool := false - output := "" - - // Initialize the streaming session with a message_start event - // This is only sent for the very first response chunk - if !hasFirstResponse { - output = "event: message_start\n" - - // Create the initial message structure with default values - // This follows the Claude API specification for streaming message initialization - messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}` - - // Override default values with actual response metadata if available - if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() { - messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String()) - } - if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() { - messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIdResult.String()) - } - output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate) - } - - // Process the response parts array from the backend client - // Each part can contain text content, thinking content, or function calls - partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts") - if partsResult.IsArray() { - partResults := partsResult.Array() - for i := 0; i < len(partResults); i++ { - partResult := partResults[i] - - // Extract the different types of content from each part - partTextResult := partResult.Get("text") - functionCallResult := partResult.Get("functionCall") - - // Handle text content (both regular content and thinking) - if partTextResult.Exists() { - // Process thinking content (internal reasoning) - if partResult.Get("thought").Bool() { - // Continue existing thinking block - if *responseType == 2 { - output = output + "event: content_block_delta\n" - data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String()) - output = output + fmt.Sprintf("data: %s\n\n\n", data) - } else { - // Transition from another state to thinking - // First, close any existing content block - if *responseType != 0 { - if *responseType == 2 { - output = output + "event: content_block_delta\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex) - output = output + "\n\n\n" - } - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) - output = output + "\n\n\n" - *responseIndex++ - } - - // Start a new thinking content block - output = output + "event: content_block_start\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex) - output = output + "\n\n\n" - output = output + "event: content_block_delta\n" - data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String()) - output = output + fmt.Sprintf("data: %s\n\n\n", data) - *responseType = 2 // Set state to thinking - } - } else { - // Process regular text content (user-visible output) - // Continue existing text block - if *responseType == 1 { - output = output + "event: content_block_delta\n" - data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String()) - output = output + fmt.Sprintf("data: %s\n\n\n", data) - } else { - // Transition from another state to text content - // First, close any existing content block - if *responseType != 0 { - if *responseType == 2 { - output = output + "event: content_block_delta\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex) - output = output + "\n\n\n" - } - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) - output = output + "\n\n\n" - *responseIndex++ - } - - // Start a new text content block - output = output + "event: content_block_start\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex) - output = output + "\n\n\n" - output = output + "event: content_block_delta\n" - data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String()) - output = output + fmt.Sprintf("data: %s\n\n\n", data) - *responseType = 1 // Set state to content - } - } - } else if functionCallResult.Exists() { - // Handle function/tool calls from the AI model - // This processes tool usage requests and formats them for Claude API compatibility - usedTool = true - fcName := functionCallResult.Get("name").String() - - // Handle state transitions when switching to function calls - // Close any existing function call block first - if *responseType == 3 { - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) - output = output + "\n\n\n" - *responseIndex++ - *responseType = 0 - } - - // Special handling for thinking state transition - if *responseType == 2 { - output = output + "event: content_block_delta\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex) - output = output + "\n\n\n" - } - - // Close any other existing content block - if *responseType != 0 { - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) - output = output + "\n\n\n" - *responseIndex++ - } - - // Start a new tool use content block - // This creates the structure for a function call in Claude format - output = output + "event: content_block_start\n" - - // Create the tool use block with unique ID and function details - data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex) - data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) - data, _ = sjson.Set(data, "content_block.name", fcName) - output = output + fmt.Sprintf("data: %s\n\n\n", data) - - if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { - output = output + "event: content_block_delta\n" - data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw) - output = output + fmt.Sprintf("data: %s\n\n\n", data) - } - *responseType = 3 - } - } - } - - usageResult := gjson.GetBytes(rawJson, "response.usageMetadata") - if usageResult.Exists() && bytes.Contains(rawJson, []byte(`"finishReason"`)) { - if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { - output = output + "event: content_block_stop\n" - output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) - output = output + "\n\n\n" - - output = output + "event: message_delta\n" - output = output + `data: ` - - template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` - if usedTool { - template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` - } - - thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() - template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount) - template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int()) - - output = output + template + "\n\n\n" - } - } - - return output -} diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 89d6c3dd..8a67c3c9 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,3 +1,6 @@ +// Package auth provides OAuth2 authentication functionality for Google Cloud APIs. +// It handles the complete OAuth2 flow including token storage, web-based authentication, +// proxy support, and automatic token refresh. The package supports both SOCKS5 and HTTP/HTTPS proxies. package auth import ( @@ -39,7 +42,7 @@ var ( // initiating a new web-based OAuth flow if necessary, and refreshing tokens. func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) { // Configure proxy settings for the HTTP client if a proxy URL is provided. - proxyURL, err := url.Parse(cfg.ProxyUrl) + proxyURL, err := url.Parse(cfg.ProxyURL) if err == nil { var transport *http.Transport if proxyURL.Scheme == "socks5" { diff --git a/internal/client/client.go b/internal/client/client.go index 3a9e92b1..146e49b8 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -1,3 +1,7 @@ +// Package client provides HTTP client functionality for interacting with Google Cloud AI APIs. +// It handles OAuth2 authentication, token management, request/response processing, +// streaming communication, quota management, and automatic model fallback. +// The package supports both direct API key authentication and OAuth2 flows. package client import ( @@ -29,7 +33,7 @@ const ( pluginVersion = "0.1.9" glEndPoint = "https://generativelanguage.googleapis.com" - glApiVersion = "v1beta" + glAPIVersion = "v1beta" ) var ( @@ -64,30 +68,37 @@ func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Confi } } +// SetProjectID updates the project ID for the client's token storage. func (c *Client) SetProjectID(projectID string) { c.tokenStorage.ProjectID = projectID } +// SetIsAuto configures whether the client should operate in automatic mode. func (c *Client) SetIsAuto(auto bool) { c.tokenStorage.Auto = auto } +// SetIsChecked sets the checked status for the client's token storage. func (c *Client) SetIsChecked(checked bool) { c.tokenStorage.Checked = checked } +// IsChecked returns whether the client's token storage has been checked. func (c *Client) IsChecked() bool { return c.tokenStorage.Checked } +// IsAuto returns whether the client is operating in automatic mode. func (c *Client) IsAuto() bool { return c.tokenStorage.Auto } +// GetEmail returns the email address associated with the client's token storage. func (c *Client) GetEmail() string { return c.tokenStorage.Email } +// GetProjectID returns the Google Cloud project ID from the client's token storage. func (c *Client) GetProjectID() string { if c.tokenStorage != nil { return c.tokenStorage.ProjectID @@ -95,6 +106,7 @@ func (c *Client) GetProjectID() string { return "" } +// GetGenerativeLanguageAPIKey returns the generative language API key if configured. func (c *Client) GetGenerativeLanguageAPIKey() string { return c.glAPIKey } @@ -267,10 +279,10 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface } else { if endpoint == "countTokens" { modelResult := gjson.GetBytes(jsonBody, "model") - url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint) + url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint) } else { modelResult := gjson.GetBytes(jsonBody, "model") - url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint) + url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint) if alt == "" && stream { url = url + "?alt=sse" } else { @@ -333,7 +345,7 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface } // SendMessage handles a single conversational turn, including tool calls. -func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) { +func (c *Client) SendMessage(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) { request := GenerateContentRequest{ Contents: contents, GenerationConfig: GenerationConfig{ @@ -357,7 +369,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, // log.Debug(string(byteRequestBody)) - reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort") + reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort") if reasoningEffortResult.String() == "none" { byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts") byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0) @@ -373,17 +385,17 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) } - temperatureResult := gjson.GetBytes(rawJson, "temperature") + temperatureResult := gjson.GetBytes(rawJSON, "temperature") if temperatureResult.Exists() && temperatureResult.Type == gjson.Number { byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num) } - topPResult := gjson.GetBytes(rawJson, "top_p") + topPResult := gjson.GetBytes(rawJSON, "top_p") if topPResult.Exists() && topPResult.Type == gjson.Number { byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num) } - topKResult := gjson.GetBytes(rawJson, "top_k") + topKResult := gjson.GetBytes(rawJSON, "top_k") if topKResult.Exists() && topKResult.Type == gjson.Number { byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) } @@ -430,7 +442,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, // This function implements a sophisticated streaming system that supports tool calls, reasoning modes, // quota management, and automatic model fallback. It returns two channels for asynchronous communication: // one for streaming response data and another for error handling. -func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) { +func (c *Client) SendMessageStream(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) { // Define the data prefix used in Server-Sent Events streaming format dataTag := []byte("data: ") @@ -486,7 +498,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st // Parse and configure reasoning effort levels from the original request // This maps Claude-style reasoning effort parameters to Gemini's thinking budget system - reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort") + reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort") if reasoningEffortResult.String() == "none" { // Disable thinking entirely for fastest responses byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts") @@ -510,21 +522,21 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st // Configure temperature parameter for response randomness control // Temperature affects the creativity vs consistency trade-off in responses - temperatureResult := gjson.GetBytes(rawJson, "temperature") + temperatureResult := gjson.GetBytes(rawJSON, "temperature") if temperatureResult.Exists() && temperatureResult.Type == gjson.Number { byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num) } // Configure top-p parameter for nucleus sampling // Controls the cumulative probability threshold for token selection - topPResult := gjson.GetBytes(rawJson, "top_p") + topPResult := gjson.GetBytes(rawJSON, "top_p") if topPResult.Exists() && topPResult.Type == gjson.Number { byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num) } // Configure top-k parameter for limiting token candidates // Restricts the model to consider only the top K most likely tokens - topKResult := gjson.GetBytes(rawJson, "top_k") + topKResult := gjson.GetBytes(rawJSON, "top_k") if topKResult.Exists() && topKResult.Type == gjson.Number { byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) } @@ -608,8 +620,8 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st } // SendRawTokenCount handles a token count. -func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) { - modelResult := gjson.GetBytes(rawJson, "model") +func (c *Client) SendRawTokenCount(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) { + modelResult := gjson.GetBytes(rawJSON, "model") model := modelResult.String() modelName := model for { @@ -618,7 +630,7 @@ func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt stri modelName = c.getPreviewModel(model) if modelName != "" { log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) - rawJson, _ = sjson.SetBytes(rawJson, "model", modelName) + rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName) continue } } @@ -628,7 +640,7 @@ func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt stri } } - respBody, err := c.APIRequest(ctx, "countTokens", rawJson, alt, false) + respBody, err := c.APIRequest(ctx, "countTokens", rawJSON, alt, false) if err != nil { if err.StatusCode == 429 { now := time.Now() @@ -649,12 +661,12 @@ func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt stri } // SendRawMessage handles a single conversational turn, including tool calls. -func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) { +func (c *Client) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) { if c.glAPIKey == "" { - rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID()) + rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID()) } - modelResult := gjson.GetBytes(rawJson, "model") + modelResult := gjson.GetBytes(rawJSON, "model") model := modelResult.String() modelName := model for { @@ -663,7 +675,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string) modelName = c.getPreviewModel(model) if modelName != "" { log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) - rawJson, _ = sjson.SetBytes(rawJson, "model", modelName) + rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName) continue } } @@ -673,7 +685,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string) } } - respBody, err := c.APIRequest(ctx, "generateContent", rawJson, alt, false) + respBody, err := c.APIRequest(ctx, "generateContent", rawJSON, alt, false) if err != nil { if err.StatusCode == 429 { now := time.Now() @@ -694,7 +706,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string) } // SendRawMessageStream handles a single conversational turn, including tool calls. -func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) { +func (c *Client) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) { dataTag := []byte("data: ") errChan := make(chan *ErrorMessage) dataChan := make(chan []byte) @@ -703,10 +715,10 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s defer close(dataChan) if c.glAPIKey == "" { - rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID()) + rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID()) } - modelResult := gjson.GetBytes(rawJson, "model") + modelResult := gjson.GetBytes(rawJSON, "model") model := modelResult.String() modelName := model var stream io.ReadCloser @@ -716,7 +728,7 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s modelName = c.getPreviewModel(model) if modelName != "" { log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) - rawJson, _ = sjson.SetBytes(rawJson, "model", modelName) + rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName) continue } } @@ -727,7 +739,7 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s return } var err *ErrorMessage - stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, alt, true) + stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJSON, alt, true) if err != nil { if err.StatusCode == 429 { now := time.Now() @@ -774,6 +786,8 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s return dataChan, errChan } +// isModelQuotaExceeded checks if the specified model has exceeded its quota +// within the last 30 minutes. func (c *Client) isModelQuotaExceeded(model string) bool { if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey { duration := time.Now().Sub(*lastExceededTime) @@ -785,6 +799,8 @@ func (c *Client) isModelQuotaExceeded(model string) bool { return false } +// getPreviewModel returns an available preview model for the given base model, +// or an empty string if no preview models are available or all are quota exceeded. func (c *Client) getPreviewModel(model string) string { if models, hasKey := previewModels[model]; hasKey { for i := 0; i < len(models); i++ { @@ -796,6 +812,8 @@ func (c *Client) getPreviewModel(model string) string { return "" } +// IsModelQuotaExceeded returns true if the specified model has exceeded its quota +// and no fallback options are available. func (c *Client) IsModelQuotaExceeded(model string) bool { if c.isModelQuotaExceeded(model) { if c.cfg.QuotaExceeded.SwitchPreviewModel { @@ -824,20 +842,20 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) { if err != nil { // If a 403 Forbidden error occurs, it likely means the API is not enabled. if err.StatusCode == 403 { - errJson := err.Error.Error() + errJSON := err.Error.Error() // Check for a specific error code and extract the activation URL. - if gjson.Get(errJson, "error.code").Int() == 403 { - activationUrl := gjson.Get(errJson, "error.details.0.metadata.activationUrl").String() - if activationUrl != "" { + if gjson.Get(errJSON, "error.code").Int() == 403 { + activationURL := gjson.Get(errJSON, "error.details.0.metadata.activationUrl").String() + if activationURL != "" { log.Warnf( "\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s", - activationUrl, + activationURL, os.Args[0], c.tokenStorage.ProjectID, ) } } - log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJson) + log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJSON) return false, nil } return false, err.Error diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 0c261998..5d98e160 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -1,3 +1,6 @@ +// Package cmd provides command-line interface functionality for the CLI Proxy API. +// It implements the main application commands including login/authentication +// and server startup, handling the complete user onboarding and service lifecycle. package cmd import ( diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 37b0118c..b946bbe1 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -1,3 +1,8 @@ +// Package cmd provides the main service execution functionality for the CLIProxyAPI. +// It contains the core logic for starting and managing the API proxy service, +// including authentication client management, server initialization, and graceful shutdown handling. +// The package handles loading authentication tokens, creating client pools, starting the API server, +// and monitoring configuration changes through file watchers. package cmd import ( diff --git a/internal/config/config.go b/internal/config/config.go index 534c565b..0e8368a3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,3 +1,7 @@ +// Package config provides configuration management for the CLI Proxy API server. +// It handles loading and parsing YAML configuration files, and provides structured +// access to application settings including server port, authentication directory, +// debug settings, proxy configuration, and API keys. package config import ( @@ -14,17 +18,19 @@ type Config struct { AuthDir string `yaml:"auth-dir"` // Debug enables or disables debug-level logging and other debug features. Debug bool `yaml:"debug"` - // ProxyUrl is the URL of an optional proxy server to use for outbound requests. - ProxyUrl string `yaml:"proxy-url"` - // ApiKeys is a list of keys for authenticating clients to this proxy server. - ApiKeys []string `yaml:"api-keys"` + // ProxyURL is the URL of an optional proxy server to use for outbound requests. + ProxyURL string `yaml:"proxy-url"` + // APIKeys is a list of keys for authenticating clients to this proxy server. + APIKeys []string `yaml:"api-keys"` // QuotaExceeded defines the behavior when a quota is exceeded. - QuotaExceeded ConfigQuotaExceeded `yaml:"quota-exceeded"` + QuotaExceeded QuotaExceeded `yaml:"quota-exceeded"` // GlAPIKey is the API key for the generative language API. GlAPIKey []string `yaml:"generative-language-api-key"` } -type ConfigQuotaExceeded struct { +// QuotaExceeded defines the behavior when API quota limits are exceeded. +// It provides configuration options for automatic failover mechanisms. +type QuotaExceeded struct { // SwitchProject indicates whether to automatically switch to another project when a quota is exceeded. SwitchProject bool `yaml:"switch-project"` // SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded. diff --git a/internal/util/proxy.go b/internal/util/proxy.go index f48c6fc3..dbf80a02 100644 --- a/internal/util/proxy.go +++ b/internal/util/proxy.go @@ -1,3 +1,6 @@ +// Package util provides utility functions for the CLI Proxy API server. +// It includes helper functions for proxy configuration, HTTP client setup, +// and other common operations used across the application. package util import ( @@ -9,9 +12,12 @@ import ( "net/url" ) +// SetProxy configures the provided HTTP client with proxy settings from the configuration. +// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport +// to route requests through the configured proxy server. func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) { var transport *http.Transport - proxyURL, errParse := url.Parse(cfg.ProxyUrl) + proxyURL, errParse := url.Parse(cfg.ProxyURL) if errParse == nil { if proxyURL.Scheme == "socks5" { username := proxyURL.User.Username() diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 0efbaa5f..68240140 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -1,3 +1,7 @@ +// Package watcher provides file system monitoring functionality for the CLI Proxy API. +// It watches configuration files and authentication directories for changes, +// automatically reloading clients and configuration when files are modified. +// The package handles cross-platform file system events and supports hot-reloading. package watcher import ( @@ -156,11 +160,11 @@ func (w *Watcher) reloadConfig() { if oldConfig.Debug != newConfig.Debug { log.Debugf(" debug: %t -> %t", oldConfig.Debug, newConfig.Debug) } - if oldConfig.ProxyUrl != newConfig.ProxyUrl { - log.Debugf(" proxy-url: %s -> %s", oldConfig.ProxyUrl, newConfig.ProxyUrl) + if oldConfig.ProxyURL != newConfig.ProxyURL { + log.Debugf(" proxy-url: %s -> %s", oldConfig.ProxyURL, newConfig.ProxyURL) } - if len(oldConfig.ApiKeys) != len(newConfig.ApiKeys) { - log.Debugf(" api-keys count: %d -> %d", len(oldConfig.ApiKeys), len(newConfig.ApiKeys)) + if len(oldConfig.APIKeys) != len(newConfig.APIKeys) { + log.Debugf(" api-keys count: %d -> %d", len(oldConfig.APIKeys), len(newConfig.APIKeys)) } if len(oldConfig.GlAPIKey) != len(newConfig.GlAPIKey) { log.Debugf(" generative-language-api-key count: %d -> %d", len(oldConfig.GlAPIKey), len(newConfig.GlAPIKey)) @@ -248,7 +252,7 @@ func (w *Watcher) reloadClients() { log.Debugf("auth directory scan complete - found %d .json files, %d successful authentications", authFileCount, successfulAuthCount) // Add clients for Generative Language API keys if configured - glApiKeyCount := 0 + glAPIKeyCount := 0 if len(cfg.GlAPIKey) > 0 { log.Debugf("processing %d Generative Language API keys", len(cfg.GlAPIKey)) for i := 0; i < len(cfg.GlAPIKey); i++ { @@ -261,9 +265,9 @@ func (w *Watcher) reloadClients() { log.Debugf(" initializing with Generative Language API key %d...", i+1) cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i]) newClients = append(newClients, cliClient) - glApiKeyCount++ + glAPIKeyCount++ } - log.Debugf("successfully initialized %d Generative Language API key clients", glApiKeyCount) + log.Debugf("successfully initialized %d Generative Language API key clients", glAPIKeyCount) } // Update the client list @@ -272,7 +276,7 @@ func (w *Watcher) reloadClients() { w.clientsMutex.Unlock() log.Infof("client reload complete - old: %d clients, new: %d clients (%d auth files + %d GL API keys)", - oldClientCount, len(newClients), successfulAuthCount, glApiKeyCount) + oldClientCount, len(newClients), successfulAuthCount, glAPIKeyCount) // Trigger the callback to update the server if w.reloadCallback != nil {