From bea5f97cbfd3e4f3ca265bfd771b800a0dbc6e70 Mon Sep 17 00:00:00 2001
From: Luis Pater
Date: Thu, 28 Aug 2025 00:30:46 +0800
Subject: [PATCH] Add `/v1/completions` endpoint with OpenAI compatibility

- Implemented `/v1/completions` endpoint mirroring OpenAI's completions API specification.
- Added conversion functions to translate between completions and chat completions formats.
- Introduced streaming and non-streaming response handling for completions requests.
- Updated `server.go` to register the new endpoint and include it in the API's metadata.
---
 .../api/handlers/openai/openai_handlers.go    | 446 ++++++++++++++++++
 internal/api/server.go                        |   2 +
 internal/registry/model_registry.go           |   2 +-
 3 files changed, 449 insertions(+), 1 deletion(-)

diff --git a/internal/api/handlers/openai/openai_handlers.go b/internal/api/handlers/openai/openai_handlers.go
index bd8d0aef..e8059264 100644
--- a/internal/api/handlers/openai/openai_handlers.go
+++ b/internal/api/handlers/openai/openai_handlers.go
@@ -8,6 +8,7 @@ package openai
 
 import (
 	"context"
+	"encoding/json"
 	"fmt"
 	"net/http"
 	"time"
@@ -19,6 +20,7 @@ import (
 	"github.com/luispater/CLIProxyAPI/internal/registry"
 	log "github.com/sirupsen/logrus"
 	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
 )
 
 // OpenAIAPIHandler contains the handlers for OpenAI API endpoints.
@@ -94,6 +94,276 @@ func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
 
 }
 
+// Completions handles the /v1/completions endpoint.
+// It determines whether the request is for a streaming or non-streaming response
+// and dispatches to the matching handler based on the request's "stream" flag.
+// This endpoint follows the OpenAI completions API specification.
+//
+// Parameters:
+//   - c: The Gin context containing the HTTP request and response
+func (h *OpenAIAPIHandler) Completions(c *gin.Context) {
+	rawJSON, err := c.GetRawData()
+	// If data retrieval fails, return a 400 Bad Request error.
+	if err != nil {
+		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
+				Message: fmt.Sprintf("Invalid request: %v", err),
+				Type:    "invalid_request_error",
+			},
+		})
+		return
+	}
+
+	// Check if the client requested a streaming response.
+	streamResult := gjson.GetBytes(rawJSON, "stream")
+	if streamResult.Type == gjson.True {
+		h.handleCompletionsStreamingResponse(c, rawJSON)
+	} else {
+		h.handleCompletionsNonStreamingResponse(c, rawJSON)
+	}
+
+}
+
+// convertCompletionsRequestToChatCompletions converts OpenAI completions API request to chat completions format.
+// This allows the completions endpoint to use the existing chat completions infrastructure.
+//
+// Parameters:
+//   - rawJSON: The raw JSON bytes of the completions request
+//
+// Returns:
+//   - []byte: The converted chat completions request
+func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
+	root := gjson.ParseBytes(rawJSON)
+
+	// Extract prompt from completions request
+	prompt := root.Get("prompt").String()
+	if prompt == "" {
+		prompt = "Complete this:"
+	}
+
+	// Create chat completions structure
+	out := `{"model":"","messages":[{"role":"user","content":""}]}`
+
+	// Set model
+	if model := root.Get("model"); model.Exists() {
+		out, _ = sjson.Set(out, "model", model.String())
+	}
+
+	// Set the prompt as user message content
+	out, _ = sjson.Set(out, "messages.0.content", prompt)
+
+	// Copy other parameters from completions to chat completions
+	if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
+		out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
+	}
+
+	if temperature := root.Get("temperature"); temperature.Exists() {
+		out, _ = sjson.Set(out, "temperature", temperature.Float())
+	}
+
+	if topP := root.Get("top_p"); topP.Exists() {
+		out, _ = sjson.Set(out, "top_p", topP.Float())
+	}
+
+	if frequencyPenalty := root.Get("frequency_penalty"); frequencyPenalty.Exists() {
+		out, _ = sjson.Set(out, "frequency_penalty", frequencyPenalty.Float())
+	}
+
+	if presencePenalty := root.Get("presence_penalty"); presencePenalty.Exists() {
+		out, _ = sjson.Set(out, "presence_penalty", presencePenalty.Float())
+	}
+
+	if stop := root.Get("stop"); stop.Exists() {
+		out, _ = sjson.SetRaw(out, "stop", stop.Raw)
+	}
+
+	if stream := root.Get("stream"); stream.Exists() {
+		out, _ = sjson.Set(out, "stream", stream.Bool())
+	}
+
+	if logprobs := root.Get("logprobs"); logprobs.Exists() {
+		out, _ = sjson.Set(out, "logprobs", logprobs.Bool())
+	}
+
+	if topLogprobs := root.Get("top_logprobs"); topLogprobs.Exists() {
+		out, _ = sjson.Set(out, "top_logprobs", topLogprobs.Int())
+	}
+
+	if echo := root.Get("echo"); echo.Exists() {
+		out, _ = sjson.Set(out, "echo", echo.Bool())
+	}
+
+	return []byte(out)
+}
+
+// convertChatCompletionsResponseToCompletions converts chat completions API response back to completions format.
+// This ensures the completions endpoint returns data in the expected format.
+//
+// Parameters:
+//   - rawJSON: The raw JSON bytes of the chat completions response
+//
+// Returns:
+//   - []byte: The converted completions response
+func convertChatCompletionsResponseToCompletions(rawJSON []byte) []byte {
+	root := gjson.ParseBytes(rawJSON)
+
+	// Base completions response structure
+	out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
+
+	// Copy basic fields
+	if id := root.Get("id"); id.Exists() {
+		out, _ = sjson.Set(out, "id", id.String())
+	}
+
+	if created := root.Get("created"); created.Exists() {
+		out, _ = sjson.Set(out, "created", created.Int())
+	}
+
+	if model := root.Get("model"); model.Exists() {
+		out, _ = sjson.Set(out, "model", model.String())
+	}
+
+	if usage := root.Get("usage"); usage.Exists() {
+		out, _ = sjson.SetRaw(out, "usage", usage.Raw)
+	}
+
+	// Convert choices from chat completions to completions format
+	var choices []interface{}
+	if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
+		chatChoices.ForEach(func(_, choice gjson.Result) bool {
+			completionsChoice := map[string]interface{}{
+				"index": choice.Get("index").Int(),
+			}
+
+			// Extract text content from message.content
+			if message := choice.Get("message"); message.Exists() {
+				if content := message.Get("content"); content.Exists() {
+					completionsChoice["text"] = content.String()
+				}
+			} else if delta := choice.Get("delta"); delta.Exists() {
+				// For streaming responses, use delta.content
+				if content := delta.Get("content"); content.Exists() {
+					completionsChoice["text"] = content.String()
+				}
+			}
+
+			// Copy finish_reason verbatim so a JSON null stays null instead of becoming ""
+			if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
+				completionsChoice["finish_reason"] = finishReason.Value()
+			}
+
+			// Copy logprobs if present
+			if logprobs := choice.Get("logprobs"); logprobs.Exists() {
+				completionsChoice["logprobs"] = logprobs.Value()
+			}
+
+			choices = append(choices, completionsChoice)
+			return true
+		})
+	}
+
+	if len(choices) > 0 {
+		choicesJSON, _ := json.Marshal(choices)
+		out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
+	}
+
+	return []byte(out)
+}
+
+// convertChatCompletionsStreamChunkToCompletions converts a streaming chat completions chunk to completions format.
+// This handles the real-time conversion of streaming response chunks and filters out empty text responses.
+//
+// Parameters:
+//   - chunkData: The raw JSON bytes of a single chat completions stream chunk
+//
+// Returns:
+//   - []byte: The converted completions stream chunk, or nil if should be filtered out
+func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
+	root := gjson.ParseBytes(chunkData)
+
+	// Check if this chunk has any meaningful content
+	hasContent := false
+	if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
+		chatChoices.ForEach(func(_, choice gjson.Result) bool {
+			// Check if delta has content or finish_reason
+			if delta := choice.Get("delta"); delta.Exists() {
+				if content := delta.Get("content"); content.Exists() && content.String() != "" {
+					hasContent = true
+					return false // Break out of forEach
+				}
+			}
+			// Also check for finish_reason to ensure we don't skip final chunks
+			if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.String() != "" && finishReason.String() != "null" {
+				hasContent = true
+				return false // Break out of forEach
+			}
+			return true
+		})
+	}
+
+	// If no meaningful content, return nil to indicate this chunk should be skipped
+	if !hasContent {
+		return nil
+	}
+
+	// Base completions stream response structure
+	out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
+
+	// Copy basic fields
+	if id := root.Get("id"); id.Exists() {
+		out, _ = sjson.Set(out, "id", id.String())
+	}
+
+	if created := root.Get("created"); created.Exists() {
+		out, _ = sjson.Set(out, "created", created.Int())
+	}
+
+	if model := root.Get("model"); model.Exists() {
+		out, _ = sjson.Set(out, "model", model.String())
+	}
+
+	// Convert choices from chat completions delta to completions format
+	var choices []interface{}
+	if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
+		chatChoices.ForEach(func(_, choice gjson.Result) bool {
+			completionsChoice := map[string]interface{}{
+				"index": choice.Get("index").Int(),
+			}
+
+			// Extract text content from delta.content
+			if delta := choice.Get("delta"); delta.Exists() {
+				if content := delta.Get("content"); content.Exists() && content.String() != "" {
+					completionsChoice["text"] = content.String()
+				} else {
+					completionsChoice["text"] = ""
+				}
+			} else {
+				completionsChoice["text"] = ""
+			}
+
+			// Copy finish_reason only when set; gjson renders JSON null as "", so test for empty as well
+			if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.String() != "" && finishReason.String() != "null" {
+				completionsChoice["finish_reason"] = finishReason.String()
+			}
+
+			// Copy logprobs if present
+			if logprobs := choice.Get("logprobs"); logprobs.Exists() {
+				completionsChoice["logprobs"] = logprobs.Value()
+			}
+
+			choices = append(choices, completionsChoice)
+			return true
+		})
+	}
+
+	if len(choices) > 0 {
+		choicesJSON, _ := json.Marshal(choices)
+		out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
+	}
+
+	return []byte(out)
+}
+
 // handleNonStreamingResponse handles non-streaming chat completion responses
 // for Gemini models. It selects a client from the pool, sends the request, and
 // aggregates the response before sending it back to the client in OpenAI format.
@@ -251,3 +523,177 @@ outLoop:
 		}
 	}
 }
+
+// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
+// It converts completions request to chat completions format, sends to backend,
+// then converts the response back to completions format before sending to client.
+//
+// Parameters:
+//   - c: The Gin context containing the HTTP request and response
+//   - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
+func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, rawJSON []byte) {
+	c.Header("Content-Type", "application/json")
+
+	// Convert completions request to chat completions format
+	chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)
+
+	modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
+	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
+
+	var cliClient interfaces.Client
+	defer func() {
+		if cliClient != nil {
+			cliClient.GetRequestMutex().Unlock()
+		}
+	}()
+
+	retryCount := 0
+	for retryCount <= h.Cfg.RequestRetry {
+		var errorResponse *interfaces.ErrorMessage
+		cliClient, errorResponse = h.GetClient(modelName)
+		if errorResponse != nil {
+			c.Status(errorResponse.StatusCode)
+			_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
+			cliCancel()
+			return
+		}
+
+		// Send the converted chat completions request
+		resp, err := cliClient.SendRawMessage(cliCtx, modelName, chatCompletionsJSON, "")
+		if err != nil {
+			// Quota errors switch clients only when project switching is enabled; otherwise the 429 is forwarded below.
+			if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
+				log.Debugf("quota exceeded, switch client")
+				continue // Restart the client selection process
+			}
+			switch err.StatusCode {
+			case 403, 408, 500, 502, 503, 504:
+				log.Debugf("http status code %d, switch client", err.StatusCode)
+				retryCount++
+				continue
+			default:
+				// Forward other errors directly to the client
+				c.Status(err.StatusCode)
+				_, _ = c.Writer.Write([]byte(err.Error.Error()))
+				cliCancel(err.Error)
+			}
+			break
+		} else {
+			// Convert chat completions response back to completions format
+			completionsResp := convertChatCompletionsResponseToCompletions(resp)
+			_, _ = c.Writer.Write(completionsResp)
+			cliCancel(completionsResp)
+			break
+		}
+	}
+}
+
+// handleCompletionsStreamingResponse handles streaming completions responses.
+// It converts completions request to chat completions format, streams from backend,
+// then converts each response chunk back to completions format before sending to client.
+//
+// Parameters:
+//   - c: The Gin context containing the HTTP request and response
+//   - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
+func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
+	c.Header("Content-Type", "text/event-stream")
+	c.Header("Cache-Control", "no-cache")
+	c.Header("Connection", "keep-alive")
+	c.Header("Access-Control-Allow-Origin", "*")
+
+	// Get the http.Flusher interface to manually flush the response.
+	flusher, ok := c.Writer.(http.Flusher)
+	if !ok {
+		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
+				Message: "Streaming not supported",
+				Type:    "server_error",
+			},
+		})
+		return
+	}
+
+	// Convert completions request to chat completions format
+	chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)
+
+	modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
+	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
+
+	var cliClient interfaces.Client
+	defer func() {
+		// Ensure the client's mutex is unlocked on function exit.
+		if cliClient != nil {
+			cliClient.GetRequestMutex().Unlock()
+		}
+	}()
+
+	retryCount := 0
+outLoop:
+	for retryCount <= h.Cfg.RequestRetry {
+		var errorResponse *interfaces.ErrorMessage
+		cliClient, errorResponse = h.GetClient(modelName)
+		if errorResponse != nil {
+			c.Status(errorResponse.StatusCode)
+			_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
+			flusher.Flush()
+			cliCancel()
+			return
+		}
+
+		// Send the converted chat completions request and receive response chunks
+		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, modelName, chatCompletionsJSON, "")
+
+		for {
+			select {
+			// Handle client disconnection.
+			case <-c.Request.Context().Done():
+				// A finished request context stays finished, so stop on any error
+				// (cancellation or deadline) rather than spinning on this case.
+				log.Debugf("client disconnected: %v", c.Request.Context().Err())
+				cliCancel() // Cancel the backend request.
+				return
+			// Process incoming response chunks.
+			case chunk, okStream := <-respChan:
+				if !okStream {
+					// Stream is closed, send the final [DONE] message.
+					_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
+					flusher.Flush()
+					cliCancel()
+					return
+				}
+
+				// Convert chat completions chunk to completions chunk format
+				completionsChunk := convertChatCompletionsStreamChunkToCompletions(chunk)
+				// Skip this chunk if it has no meaningful content (empty text)
+				if completionsChunk != nil {
+					_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(completionsChunk))
+					flusher.Flush()
+				}
+			// Handle errors from the backend.
+			case err, okError := <-errChan:
+				if okError {
+					// Quota errors switch clients only when project switching is enabled; otherwise the 429 is forwarded below.
+					if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
+						log.Debugf("quota exceeded, switch client")
+						continue outLoop // Restart the client selection process
+					}
+					switch err.StatusCode {
+					case 403, 408, 500, 502, 503, 504:
+						log.Debugf("http status code %d, switch client", err.StatusCode)
+						retryCount++
+						continue outLoop
+					default:
+						// Forward other errors directly to the client
+						c.Status(err.StatusCode)
+						_, _ = fmt.Fprint(c.Writer, err.Error.Error())
+						flusher.Flush()
+						cliCancel(err.Error)
+					}
+					return
+				}
+			// Wake up periodically so the select never blocks indefinitely; nothing is written to the client here.
+			case <-time.After(500 * time.Millisecond):
+			}
+		}
+	}
+}
diff --git a/internal/api/server.go b/internal/api/server.go
index 381c765e..7216707a 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -104,6 +104,7 @@ func (s *Server) setupRoutes() {
 	{
 		v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers))
 		v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
+		v1.POST("/completions", openaiHandlers.Completions)
 		v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
 	}
 
@@ -123,6 +124,7 @@ func (s *Server) setupRoutes() {
 			"version":   "1.0.0",
 			"endpoints": []string{
 				"POST /v1/chat/completions",
+				"POST /v1/completions",
 				"GET /v1/models",
 			},
 		})
diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go
index 0221f5fb..619b0e11 100644
--- a/internal/registry/model_registry.go
+++ b/internal/registry/model_registry.go
@@ -185,7 +185,7 @@ func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
 
 	if registration, exists := r.models[modelID]; exists {
 		delete(registration.QuotaExceededClients, clientID)
-		log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
+		// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
 	}
 }
 