From 273e1d9cbe6713040d0fcc6f8c58f64c7d4da7c5 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Thu, 10 Jul 2025 05:16:54 +0800 Subject: [PATCH] Add system instruction support and enhance internal API handlers - Introduced `SystemInstruction` field in `PrepareRequest` and `GenerateContentRequest` for better message parsing. - Updated `SendMessage` and `SendMessageStream` to handle system instructions in client API calls. - Enhanced error handling and manual flushing logic in response flows. - Added new internal API endpoints `/v1internal:generateContent` and `/v1internal:streamGenerateContent`. - Improved proxy handling and transport logic in HTTP client initialization. --- internal/api/handlers.go | 327 +++++++++++++++++++++++++++-- internal/api/server.go | 2 + internal/api/translator/request.go | 11 +- internal/client/client.go | 130 +++++++++++- internal/client/models.go | 7 +- 5 files changed, 442 insertions(+), 35 deletions(-) diff --git a/internal/api/handlers.go b/internal/api/handlers.go index edf8ddc7..93404459 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -1,6 +1,7 @@ package api import ( + "bytes" "context" "fmt" "github.com/luispater/CLIProxyAPI/internal/api/translator" @@ -8,7 +9,12 @@ import ( "github.com/luispater/CLIProxyAPI/internal/config" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "golang.org/x/net/proxy" + "io" + "net" "net/http" + "net/url" "sync" "time" @@ -196,19 +202,7 @@ func (h *APIHandlers) ChatCompletions(c *gin.Context) { func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) { c.Header("Content-Type", "application/json") - // Handle streaming manually - flusher, ok := c.Writer.(http.Flusher) - if !ok { - c.JSON(http.StatusInternalServerError, ErrorResponse{ - Error: ErrorDetail{ - Message: "Streaming not supported", - Type: "server_error", - }, - }) - return - } - - modelName, contents, tools := translator.PrepareRequest(rawJson) + 
modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson) cliCtx, cliCancel := context.WithCancel(context.Background()) var cliClient *client.Client defer func() { @@ -239,8 +233,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) if len(reorderedClients) == 0 { c.Status(429) - _, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)) - flusher.Flush() + _, _ = c.Writer.Write([]byte(fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))) cliCancel() return } @@ -266,22 +259,20 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) } - resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools) + resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, systemInstruction, contents, tools) if err != nil { if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { continue } else { c.Status(err.StatusCode) - _, _ = fmt.Fprint(c.Writer, err.Error.Error()) - flusher.Flush() + _, _ = c.Writer.Write([]byte(err.Error.Error())) cliCancel() } break } else { openAIFormat := translator.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey) if openAIFormat != "" { - _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) - flusher.Flush() + _, _ = c.Writer.Write([]byte(openAIFormat)) } cliCancel() break @@ -309,7 +300,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) { } // Prepare the request for the backend client. 
- modelName, contents, tools := translator.PrepareRequest(rawJson) + modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson) cliCtx, cliCancel := context.WithCancel(context.Background()) var cliClient *client.Client defer func() { @@ -369,7 +360,7 @@ outLoop: log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) } // Send the message and receive response chunks and errors via channels. - respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools) + respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools) hasFirstResponse := false for { select { @@ -420,3 +411,295 @@ outLoop: } } } + +func (h *APIHandlers) Internal(c *gin.Context) { + rawJson, _ := c.GetRawData() + requestRawURI := c.Request.URL.Path + if requestRawURI == "/v1internal:generateContent" { + h.internalGenerateContent(c, rawJson) + } else if requestRawURI == "/v1internal:streamGenerateContent" { + h.internalStreamGenerateContent(c, rawJson) + } else { + reqBody := bytes.NewBuffer(rawJson) + req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody) + if err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse{ + Error: ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + for key, value := range c.Request.Header { + req.Header[key] = value + } + + var transport *http.Transport + proxyURL, errParse := url.Parse(h.cfg.ProxyUrl) + if errParse == nil { + if proxyURL.Scheme == "socks5" { + username := proxyURL.User.Username() + password, _ := proxyURL.User.Password() + proxyAuth := &proxy.Auth{User: username, Password: password} + dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) + if errSOCKS5 != nil { + log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5) + } + transport = 
&http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + }, + } + } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { + transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} + } + } + httpClient := &http.Client{} + if transport != nil { + httpClient.Transport = transport + } + + resp, err := httpClient.Do(req) + if err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse{ + Error: ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer func() { + if err = resp.Body.Close(); err != nil { + log.Printf("warn: failed to close response body: %v", err) + } + }() + bodyBytes, _ := io.ReadAll(resp.Body) + + c.JSON(http.StatusBadRequest, ErrorResponse{ + Error: ErrorDetail{ + Message: string(bodyBytes), + Type: "invalid_request_error", + }, + }) + return + } + + defer func() { + _ = resp.Body.Close() + }() + + for key, value := range resp.Header { + c.Header(key, value[0]) + } + output, err := io.ReadAll(resp.Body) + if err != nil { + log.Errorf("Failed to read response body: %v", err) + return + } + _, _ = c.Writer.Write(output) + } +} + +func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) { + // Get the http.Flusher interface to manually flush the response. + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, ErrorResponse{ + Error: ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + modelResult := gjson.GetBytes(rawJson, "model") + modelName := modelResult.String() + + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient *client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. 
+ if cliClient != nil { + cliClient.RequestMutex.Unlock() + } + }() + +outLoop: + for { + // Lock the mutex to update the last used client index + mutex.Lock() + startIndex := lastUsedClientIndex + currentIndex := (startIndex + 1) % len(h.cliClients) + lastUsedClientIndex = currentIndex + mutex.Unlock() + + // Reorder the client to start from the last used index + reorderedClients := make([]*client.Client, 0) + for i := 0; i < len(h.cliClients); i++ { + cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)] + if cliClient.IsModelQuotaExceeded(modelName) { + log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) + cliClient = nil + continue + } + reorderedClients = append(reorderedClients, cliClient) + } + + if len(reorderedClients) == 0 { + c.Status(429) + _, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)) + flusher.Flush() + cliCancel() + return + } + + locked := false + for i := 0; i < len(reorderedClients); i++ { + cliClient = reorderedClients[i] + if cliClient.RequestMutex.TryLock() { + locked = true + break + } + } + if !locked { + cliClient = h.cliClients[0] + cliClient.RequestMutex.Lock() + } + + if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + } + // Send the message and receive response chunks and errors via channels. + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson) + hasFirstResponse := false + for { + select { + // Handle client disconnection. 
+ case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("Client disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. + case chunk, okStream := <-respChan: + if !okStream { + cliCancel() + return + } else { + hasFirstResponse = true + if cliClient.GetGenerativeLanguageAPIKey() != "" { + chunk, _ = sjson.SetRawBytes(chunk, "response", chunk) + } + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + flusher.Flush() + } + // Handle errors from the backend. + case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + continue outLoop + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel() + } + return + } + // Send a keep-alive signal to the client. + case <-time.After(500 * time.Millisecond): + if hasFirstResponse { + _, _ = c.Writer.Write([]byte("\n")) + flusher.Flush() + } + } + } + } +} + +func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) { + c.Header("Content-Type", "application/json") + + modelResult := gjson.GetBytes(rawJson, "model") + modelName := modelResult.String() + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient *client.Client + defer func() { + if cliClient != nil { + cliClient.RequestMutex.Unlock() + } + }() + + for { + // Lock the mutex to update the last used client index + mutex.Lock() + startIndex := lastUsedClientIndex + currentIndex := (startIndex + 1) % len(h.cliClients) + lastUsedClientIndex = currentIndex + mutex.Unlock() + + // Reorder the client to start from the last used index + reorderedClients := make([]*client.Client, 0) + for i := 0; i < len(h.cliClients); i++ { + cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)] + if 
cliClient.IsModelQuotaExceeded(modelName) { + log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) + cliClient = nil + continue + } + reorderedClients = append(reorderedClients, cliClient) + } + + if len(reorderedClients) == 0 { + c.Status(429) + _, _ = c.Writer.Write([]byte(fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))) + cliCancel() + return + } + + locked := false + for i := 0; i < len(reorderedClients); i++ { + cliClient = reorderedClients[i] + if cliClient.RequestMutex.TryLock() { + locked = true + break + } + } + if !locked { + cliClient = h.cliClients[0] + cliClient.RequestMutex.Lock() + } + + if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + } + + resp, err := cliClient.SendRawMessage(cliCtx, rawJson) + if err != nil { + if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + continue + } else { + c.Status(err.StatusCode) + _, _ = c.Writer.Write([]byte(err.Error.Error())) + cliCancel() + } + break + } else { + _, _ = c.Writer.Write(resp) + cliCancel() + break + } + } +} diff --git a/internal/api/server.go b/internal/api/server.go index d507eabc..531c4a25 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -81,6 +81,8 @@ func (s *Server) setupRoutes() { }, }) }) + s.engine.POST("/v1internal:method", s.handlers.Internal) + } // Start begins listening for and serving HTTP requests. 
diff --git a/internal/api/translator/request.go b/internal/api/translator/request.go index 69a9794e..634e74e8 100644 --- a/internal/api/translator/request.go +++ b/internal/api/translator/request.go @@ -12,7 +12,7 @@ import ( // PrepareRequest translates a raw JSON request from an OpenAI-compatible format // to the internal format expected by the backend client. It parses messages, // roles, content types (text, image, file), and tool calls. -func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) { +func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) { // Extract the model name from the request, defaulting to "gemini-2.5-pro". modelName := "gemini-2.5-pro" modelResult := gjson.GetBytes(rawJson, "model") @@ -22,6 +22,7 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl // Process the array of messages. contents := make([]client.Content, 0) + var systemInstruction *client.Content messagesResult := gjson.GetBytes(rawJson, "messages") if messagesResult.IsArray() { messagesResults := messagesResult.Array() @@ -37,13 +38,11 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl // System messages are converted to a user message followed by a model's acknowledgment. case "system": if contentResult.Type == gjson.String { - contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}) - contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}}) + systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}} } else if contentResult.IsObject() { // Handle object-based system messages. 
if contentResult.Get("type").String() == "text" { - contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}}) - contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}}) + systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}} } } // User messages can contain simple text or a multi-part body. @@ -159,5 +158,5 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl tools = make([]client.ToolDeclaration, 0) } - return modelName, contents, tools + return modelName, systemInstruction, contents, tools } diff --git a/internal/client/client.go b/internal/client/client.go index 12d4f8f6..32793a12 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -266,6 +266,12 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface url = url + "?alt=sse" } jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw) + systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction") + if systemInstructionResult.Exists() { + jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw)) + jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction") + jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id") + } } // log.Debug(string(jsonBody)) @@ -303,15 +309,14 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface } }() bodyBytes, _ := io.ReadAll(resp.Body) - return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))} } return resp.Body, nil } -// SendMessageStream handles a single conversational turn, including tool calls. 
-func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) { +// SendMessage handles a single conversational turn, including tool calls. +func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) { request := GenerateContentRequest{ Contents: contents, GenerationConfig: GenerationConfig{ @@ -320,6 +325,9 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, }, }, } + + request.SystemInstruction = systemInstruction + request.Tools = tools requestBody := map[string]interface{}{ @@ -402,7 +410,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, } // SendMessageStream handles a single conversational turn, including tool calls. -func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) { +func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) { dataTag := []byte("data: ") errChan := make(chan *ErrorMessage) dataChan := make(chan []byte) @@ -418,6 +426,9 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st }, }, } + + request.SystemInstruction = systemInstruction + request.Tools = tools requestBody := map[string]interface{}{ @@ -519,6 +530,117 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st return dataChan, errChan } +// SendRawMessage forwards a raw JSON request to the generateContent endpoint and returns the complete response body. 
+func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *ErrorMessage) { + rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID()) + + modelResult := gjson.GetBytes(rawJson, "model") + model := modelResult.String() + modelName := model + for { + if c.isModelQuotaExceeded(modelName) { + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + modelName = c.getPreviewModel(model) + if modelName != "" { + log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) + rawJson, _ = sjson.SetBytes(rawJson, "model", modelName) + continue + } + } + return nil, &ErrorMessage{ + StatusCode: 429, + Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), + } + } + + respBody, err := c.APIRequest(ctx, "generateContent", rawJson, false) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + continue + } + } + return nil, err + } + delete(c.modelQuotaExceeded, modelName) + bodyBytes, errReadAll := io.ReadAll(respBody) + if errReadAll != nil { + return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll} + } + return bodyBytes, nil + } +} + +// SendRawMessageStream forwards a raw JSON request to the streamGenerateContent endpoint and streams back the payload of each "data: " line. 
+func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-chan []byte, <-chan *ErrorMessage) { + dataTag := []byte("data: ") + errChan := make(chan *ErrorMessage) + dataChan := make(chan []byte) + go func() { + defer close(errChan) + defer close(dataChan) + + rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID()) + + modelResult := gjson.GetBytes(rawJson, "model") + model := modelResult.String() + modelName := model + var stream io.ReadCloser + for { + if c.isModelQuotaExceeded(modelName) { + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + modelName = c.getPreviewModel(model) + if modelName != "" { + log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) + rawJson, _ = sjson.SetBytes(rawJson, "model", modelName) + continue + } + } + errChan <- &ErrorMessage{ + StatusCode: 429, + Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), + } + return + } + var err *ErrorMessage + stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, true) + if err != nil { + if err.StatusCode == 429 { + now := time.Now() + c.modelQuotaExceeded[modelName] = &now + if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { + continue + } + } + errChan <- err + return + } + delete(c.modelQuotaExceeded, modelName) + break + } + + scanner := bufio.NewScanner(stream) + for scanner.Scan() { + line := scanner.Bytes() + if bytes.HasPrefix(line, dataTag) { + dataChan <- line[6:] + } + } + + if errScanner := scanner.Err(); errScanner != nil { + errChan <- &ErrorMessage{500, errScanner} + _ = stream.Close() + return + } + + _ = stream.Close() + }() + + return dataChan, errChan +} + func (c *Client) isModelQuotaExceeded(model string) bool { if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey { duration := time.Now().Sub(*lastExceededTime) diff --git a/internal/client/models.go b/internal/client/models.go 
index a6aa9d44..eabc6b59 100644 --- a/internal/client/models.go +++ b/internal/client/models.go @@ -64,9 +64,10 @@ type FunctionResponse struct { // GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint. type GenerateContentRequest struct { - Contents []Content `json:"contents"` - Tools []ToolDeclaration `json:"tools,omitempty"` - GenerationConfig `json:"generationConfig"` + SystemInstruction *Content `json:"systemInstruction,omitempty"` + Contents []Content `json:"contents"` + Tools []ToolDeclaration `json:"tools,omitempty"` + GenerationConfig `json:"generationConfig"` } // GenerationConfig defines parameters that control the model's generation behavior.