From b3607d398196bb6ff323fec864424dd7603b7255 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Fri, 11 Jul 2025 04:01:45 +0800 Subject: [PATCH] Add Gemini-compatible API and improve error handling - Introduced a new Gemini-compatible API with routes under `/v1beta`. - Added `GeminiHandler` to manage `generateContent` and `streamGenerateContent` actions. - Enhanced `AuthMiddleware` to support `X-Goog-Api-Key` header. - Improved client metadata handling and added conditional project ID updates in API calls. - Updated logging to debug raw API request payloads for better traceability. --- internal/api/gemini-handlers.go | 215 ++++++++++++++++++++++++++++++++ internal/api/server.go | 13 +- internal/client/client.go | 9 +- 3 files changed, 233 insertions(+), 4 deletions(-) create mode 100644 internal/api/gemini-handlers.go diff --git a/internal/api/gemini-handlers.go b/internal/api/gemini-handlers.go new file mode 100644 index 00000000..c6fa8226 --- /dev/null +++ b/internal/api/gemini-handlers.go @@ -0,0 +1,215 @@ +package api + +import ( + "context" + "fmt" + "github.com/gin-gonic/gin" + "github.com/luispater/CLIProxyAPI/internal/client" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "net/http" + "strings" + "time" +) + +func (h *APIHandlers) GeminiHandler(c *gin.Context) { + var person struct { + Action string `uri:"action" binding:"required"` + } + if err := c.ShouldBindUri(&person); err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse{ + Error: ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + action := strings.Split(person.Action, ":") + if len(action) != 2 { + c.JSON(http.StatusNotFound, ErrorResponse{ + Error: ErrorDetail{ + Message: fmt.Sprintf("%s not found.", c.Request.URL.Path), + Type: "invalid_request_error", + }, + }) + return + } + + modelName := action[0] + method := action[1] + rawJson, _ := c.GetRawData() + rawJson, _ = sjson.SetBytes(rawJson, "model", []byte(modelName)) + + if method == "generateContent" { + h.geminiGenerateContent(c, rawJson) + } else if method == "streamGenerateContent" { + h.geminiStreamGenerateContent(c, rawJson) + } +} + +func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte) { + // Get the http.Flusher interface to manually flush the response. + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, ErrorResponse{ + Error: ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + modelResult := gjson.GetBytes(rawJson, "model") + modelName := modelResult.String() + + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient *client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. + if cliClient != nil { + cliClient.RequestMutex.Unlock() + } + }() + +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.getClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + flusher.Flush() + cliCancel() + return + } + + template := `{"project":"","request":{},"model":""}` + template, _ = sjson.SetRaw(template, "request", string(rawJson)) + template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) + template, _ = sjson.Delete(template, "request.model") + systemInstructionResult := gjson.Get(template, "request.system_instruction") + if systemInstructionResult.Exists() { + template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) + template, _ = sjson.Delete(template, "request.system_instruction") + } + rawJson = []byte(template) + + if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + } + + // Send the message and receive response chunks and errors via channels. + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson) + for { + select { + // Handle client disconnection. + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("Client disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. + case chunk, okStream := <-respChan: + if !okStream { + cliCancel() + return + } else { + if cliClient.GetGenerativeLanguageAPIKey() == "" { + responseResult := gjson.GetBytes(chunk, "response") + if responseResult.Exists() { + chunk = []byte(responseResult.Raw) + } + } + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + flusher.Flush() + } + // Handle errors from the backend. + case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + continue outLoop + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel() + } + return + } + // Send a keep-alive signal to the client. + case <-time.After(500 * time.Millisecond): + } + } + } +} + +func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) { + c.Header("Content-Type", "application/json") + + modelResult := gjson.GetBytes(rawJson, "model") + modelName := modelResult.String() + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient *client.Client + defer func() { + if cliClient != nil { + cliClient.RequestMutex.Unlock() + } + }() + + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.getClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + cliCancel() + return + } + + template := `{"project":"","request":{},"model":""}` + template, _ = sjson.SetRaw(template, "request", string(rawJson)) + template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) + template, _ = sjson.Delete(template, "request.model") + systemInstructionResult := gjson.Get(template, "request.system_instruction") + if systemInstructionResult.Exists() { + template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) + template, _ = sjson.Delete(template, "request.system_instruction") + } + rawJson = []byte(template) + + if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + } + resp, err := cliClient.SendRawMessage(cliCtx, rawJson) + if err != nil { + if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + continue + } else { + c.Status(err.StatusCode) + _, _ = c.Writer.Write([]byte(err.Error.Error())) + cliCancel() + } + break + } else { + if cliClient.GetGenerativeLanguageAPIKey() == "" { + responseResult := gjson.GetBytes(resp, "response") + if responseResult.Exists() { + resp = []byte(responseResult.Raw) + } + } + _, _ = c.Writer.Write(resp) + cliCancel() + break + } + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 52b32006..eb360e10 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -70,6 +70,14 @@ func (s *Server) setupRoutes() { v1.POST("/chat/completions", s.handlers.ChatCompletions) } + // Gemini compatible API routes + v1beta := s.engine.Group("/v1beta") + v1beta.Use(AuthMiddleware(s.cfg)) + { + v1beta.GET("/models", s.handlers.Models) + v1beta.POST("/models/:action", s.handlers.GeminiHandler) + } + // Root endpoint s.engine.GET("/", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ @@ -140,7 +148,8 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc { // Get the Authorization header authHeader := c.GetHeader("Authorization") - if authHeader == "" { + authHeaderGoogle := c.GetHeader("X-Goog-Api-Key") + if authHeader == "" && authHeaderGoogle == "" { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "Missing API key", }) @@ -159,7 +168,7 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc { // Find the API key in the in-memory list var foundKey string for i := range cfg.ApiKeys { - if cfg.ApiKeys[i] == apiKey { + if cfg.ApiKeys[i] == apiKey || cfg.ApiKeys[i] == authHeaderGoogle { foundKey = cfg.ApiKeys[i] break } diff --git a/internal/client/client.go b/internal/client/client.go index 9b874b30..1d5dca10 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -311,6 +311,7 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface } }() bodyBytes, _ := io.ReadAll(resp.Body) + // log.Debug(string(jsonBody)) return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))} } @@ -534,7 +535,9 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st // SendRawMessage handles a single conversational turn, including tool calls. func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *ErrorMessage) { - rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID()) + if c.glAPIKey == "" { + rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID()) + } modelResult := gjson.GetBytes(rawJson, "model") model := modelResult.String() @@ -584,7 +587,9 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch defer close(errChan) defer close(dataChan) - rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID()) + if c.glAPIKey == "" { + rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID()) + } modelResult := gjson.GetBytes(rawJson, "model") model := modelResult.String()