mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
feat(server): Add server configuration validation and optimize connection handling
- Add Validate method to ServerConfig to validate port ranges, domain format, TCP port ranges, and other configuration items - Add configuration validation logic in server.go to ensure valid configuration before server startup - Improve channel naming in TCP connections for better code readability - Enhance data copying mechanism with context cancellation support to avoid resource leaks - Add private network definitions for secure validation of trusted proxy headers fix(proxy): Strengthen client IP extraction security and fix error handling - Trust X-Forwarded-For and X-Real-IP headers only when requests originate from private/loopback networks - Define RFC 1918 and other private network ranges for proxy header validation - Add JSON serialization error handling in TCP connections to prevent data loss - Fix context handling logic in pipe callbacks - Optimize error handling mechanism for data connection responses refactor(config): Improve client configuration validation and error handling - Add Validate method to ClientConfig to verify server address format and port validity - Change configuration validation from simple checks to full validation function calls - Provide more detailed error messages to help users correctly configure server address formats
This commit is contained in:
@@ -21,17 +21,17 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
serverPort int
|
||||
serverPublicPort int
|
||||
serverDomain string
|
||||
serverAuthToken string
|
||||
serverPort int
|
||||
serverPublicPort int
|
||||
serverDomain string
|
||||
serverAuthToken string
|
||||
serverMetricsToken string
|
||||
serverDebug bool
|
||||
serverTCPPortMin int
|
||||
serverTCPPortMax int
|
||||
serverTLSCert string
|
||||
serverTLSKey string
|
||||
serverPprofPort int
|
||||
serverDebug bool
|
||||
serverTCPPortMin int
|
||||
serverTCPPortMax int
|
||||
serverTLSCert string
|
||||
serverTLSKey string
|
||||
serverPprofPort int
|
||||
)
|
||||
|
||||
var serverCmd = &cobra.Command{
|
||||
@@ -113,6 +113,10 @@ func runServer(_ *cobra.Command, _ []string) error {
|
||||
Debug: serverDebug,
|
||||
}
|
||||
|
||||
if err := serverConfig.Validate(); err != nil {
|
||||
logger.Fatal("Invalid server configuration", zap.Error(err))
|
||||
}
|
||||
|
||||
tlsConfig, err := serverConfig.LoadTLSConfig()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to load TLS configuration", zap.Error(err))
|
||||
|
||||
@@ -127,12 +127,12 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
copyDone := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stream.Close()
|
||||
case <-done:
|
||||
case <-copyDone:
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -150,7 +150,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
|
||||
break
|
||||
}
|
||||
}
|
||||
close(done)
|
||||
close(copyDone)
|
||||
}
|
||||
|
||||
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {
|
||||
|
||||
@@ -41,6 +41,24 @@ type Handler struct {
|
||||
metricsToken string
|
||||
}
|
||||
|
||||
var privateNetworks []*net.IPNet
|
||||
|
||||
func init() {
|
||||
privateCIDRs := []string{
|
||||
"127.0.0.0/8", // IPv4 loopback
|
||||
"10.0.0.0/8", // RFC 1918 Class A
|
||||
"172.16.0.0/12", // RFC 1918 Class B
|
||||
"192.168.0.0/16", // RFC 1918 Class C
|
||||
"::1/128", // IPv6 loopback
|
||||
"fc00::/7", // IPv6 unique local
|
||||
"fe80::/10", // IPv6 link-local
|
||||
}
|
||||
for _, cidr := range privateCIDRs {
|
||||
_, ipNet, _ := net.ParseCIDR(cidr)
|
||||
privateNetworks = append(privateNetworks, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string, metricsToken string) *Handler {
|
||||
return &Handler{
|
||||
manager: manager,
|
||||
@@ -167,23 +185,23 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
// Use pooled buffer for zero-copy optimization
|
||||
buf := pool.GetBuffer(pool.SizeLarge)
|
||||
defer pool.PutBuffer(buf)
|
||||
|
||||
// Copy with context cancellation support
|
||||
ctx := r.Context()
|
||||
done := make(chan struct{})
|
||||
copyDone := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stream.Close()
|
||||
case <-done:
|
||||
case <-copyDone:
|
||||
}
|
||||
}()
|
||||
|
||||
// Use pooled buffer for zero-copy optimization
|
||||
buf := pool.GetBuffer(pool.SizeLarge)
|
||||
_, _ = io.CopyBuffer(w, resp.Body, (*buf)[:])
|
||||
pool.PutBuffer(buf)
|
||||
|
||||
close(done)
|
||||
stream.Close()
|
||||
close(copyDone)
|
||||
}
|
||||
|
||||
func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) {
|
||||
@@ -192,24 +210,23 @@ func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, err
|
||||
err error
|
||||
}
|
||||
ch := make(chan result, 1)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
go func() {
|
||||
s, err := tconn.OpenStream()
|
||||
select {
|
||||
case ch <- result{s, err}:
|
||||
case <-done:
|
||||
if s != nil {
|
||||
s.Close()
|
||||
}
|
||||
}
|
||||
ch <- result{s, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case r := <-ch:
|
||||
return r.stream, r.err
|
||||
case <-time.After(openStreamTimeout):
|
||||
// Goroutine will eventually complete and send to buffered channel
|
||||
// which will be garbage collected. If stream was opened, it needs cleanup.
|
||||
go func() {
|
||||
if r := <-ch; r.stream != nil {
|
||||
r.stream.Close()
|
||||
}
|
||||
}()
|
||||
return nil, fmt.Errorf("open stream timeout")
|
||||
}
|
||||
}
|
||||
@@ -337,31 +354,58 @@ func (h *Handler) extractSubdomain(host string) string {
|
||||
}
|
||||
|
||||
// extractClientIP extracts the client IP from the request.
|
||||
// It checks X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups),
|
||||
// then falls back to the remote address.
|
||||
// It only trusts X-Forwarded-For and X-Real-IP headers when the request
|
||||
// comes from a private/loopback network (typical reverse proxy setup).
|
||||
func (h *Handler) extractClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (may contain multiple IPs)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP (original client)
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
// First, get the direct remote address
|
||||
remoteIP := h.extractRemoteIP(r.RemoteAddr)
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
// Only trust proxy headers if the request comes from a private network
|
||||
if isPrivateIP(remoteIP) {
|
||||
// Check X-Forwarded-For header (may contain multiple IPs)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP (original client)
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to remote address
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// extractRemoteIP extracts the IP address from a remote address string (host:port format).
|
||||
func (h *Handler) extractRemoteIP(remoteAddr string) string {
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
return remoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// isPrivateIP checks if the given IP is a private/loopback address.
|
||||
func isPrivateIP(ip string) bool {
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, network := range privateNetworks {
|
||||
if network.Contains(parsedIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
|
||||
html := `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
@@ -235,7 +235,10 @@ func (c *Connection) Handle() error {
|
||||
RecommendedConns: recommendedConns,
|
||||
}
|
||||
|
||||
respData, _ := json.Marshal(resp)
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal registration response: %w", err)
|
||||
}
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
|
||||
|
||||
err = protocol.WriteFrame(c.conn, ackFrame)
|
||||
@@ -409,13 +412,6 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
|
||||
}
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func parseTCPSubdomainPort(subdomain string) (int, bool) {
|
||||
if !strings.HasPrefix(subdomain, "tcp-") {
|
||||
return 0, false
|
||||
@@ -525,11 +521,15 @@ func (c *Connection) sendError(code, message string) {
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
data, _ := json.Marshal(errMsg)
|
||||
data, err := json.Marshal(errMsg)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to marshal error message", zap.Error(err))
|
||||
return
|
||||
}
|
||||
errFrame := protocol.NewFrame(protocol.FrameTypeError, data)
|
||||
|
||||
if c.frameWriter == nil {
|
||||
protocol.WriteFrame(c.conn, errFrame)
|
||||
_ = protocol.WriteFrame(c.conn, errFrame)
|
||||
} else {
|
||||
c.frameWriter.WriteFrame(errFrame)
|
||||
}
|
||||
@@ -676,7 +676,10 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
|
||||
Message: "Data connection accepted",
|
||||
}
|
||||
|
||||
respData, _ := json.Marshal(resp)
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal data connect response: %w", err)
|
||||
}
|
||||
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
|
||||
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
|
||||
@@ -732,7 +735,11 @@ func (c *Connection) sendDataConnectError(code, message string) {
|
||||
Accepted: false,
|
||||
Message: fmt.Sprintf("%s: %s", code, message),
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
respData, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
c.logger.Error("Failed to marshal data connect error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
|
||||
protocol.WriteFrame(c.conn, frame)
|
||||
_ = protocol.WriteFrame(c.conn, frame)
|
||||
}
|
||||
|
||||
@@ -64,16 +64,6 @@ func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
if ctx != nil {
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
closeAll()
|
||||
case <-stopCh:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := pipeBuffer(b, a, bufSize, onAToB, stopCh)
|
||||
@@ -92,6 +82,16 @@ func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser
|
||||
closeAll()
|
||||
}()
|
||||
|
||||
if ctx != nil {
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
closeAll()
|
||||
case <-stopCh:
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
|
||||
@@ -2,8 +2,10 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -15,6 +17,31 @@ type ClientConfig struct {
|
||||
TLS bool `yaml:"tls"` // Use TLS (always true for production)
|
||||
}
|
||||
|
||||
// Validate checks if the client configuration is valid
|
||||
func (c *ClientConfig) Validate() error {
|
||||
if c.Server == "" {
|
||||
return fmt.Errorf("server address is required")
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(c.Server)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "missing port") {
|
||||
return fmt.Errorf("server address must include port (e.g., example.com:443), got: %s", c.Server)
|
||||
}
|
||||
return fmt.Errorf("invalid server address format: %s (expected host:port)", c.Server)
|
||||
}
|
||||
|
||||
if host == "" {
|
||||
return fmt.Errorf("server host is required")
|
||||
}
|
||||
|
||||
if port == "" {
|
||||
return fmt.Errorf("server port is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultClientConfig returns the default configuration path
|
||||
func DefaultClientConfigPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
@@ -43,8 +70,8 @@ func LoadClientConfig(path string) (*ClientConfig, error) {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
if config.Server == "" {
|
||||
return nil, fmt.Errorf("server address is required in config")
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ServerConfig holds the server configuration
|
||||
@@ -30,6 +31,50 @@ type ServerConfig struct {
|
||||
Debug bool
|
||||
}
|
||||
|
||||
// Validate checks if the server configuration is valid
|
||||
func (c *ServerConfig) Validate() error {
|
||||
// Validate port
|
||||
if c.Port < 1 || c.Port > 65535 {
|
||||
return fmt.Errorf("invalid port %d: must be between 1 and 65535", c.Port)
|
||||
}
|
||||
|
||||
// Validate public port if set
|
||||
if c.PublicPort != 0 && (c.PublicPort < 1 || c.PublicPort > 65535) {
|
||||
return fmt.Errorf("invalid public port %d: must be between 1 and 65535", c.PublicPort)
|
||||
}
|
||||
|
||||
// Validate domain
|
||||
if c.Domain == "" {
|
||||
return fmt.Errorf("domain is required")
|
||||
}
|
||||
if strings.Contains(c.Domain, ":") {
|
||||
return fmt.Errorf("domain should not contain port, got: %s", c.Domain)
|
||||
}
|
||||
|
||||
// Validate TCP port range
|
||||
if c.TCPPortMin < 1 || c.TCPPortMin > 65535 {
|
||||
return fmt.Errorf("invalid TCPPortMin %d: must be between 1 and 65535", c.TCPPortMin)
|
||||
}
|
||||
if c.TCPPortMax < 1 || c.TCPPortMax > 65535 {
|
||||
return fmt.Errorf("invalid TCPPortMax %d: must be between 1 and 65535", c.TCPPortMax)
|
||||
}
|
||||
if c.TCPPortMin >= c.TCPPortMax {
|
||||
return fmt.Errorf("TCPPortMin (%d) must be less than TCPPortMax (%d)", c.TCPPortMin, c.TCPPortMax)
|
||||
}
|
||||
|
||||
// Validate TLS settings
|
||||
if c.TLSEnabled {
|
||||
if c.TLSCertFile == "" {
|
||||
return fmt.Errorf("TLS certificate file is required when TLS is enabled")
|
||||
}
|
||||
if c.TLSKeyFile == "" {
|
||||
return fmt.Errorf("TLS key file is required when TLS is enabled")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadTLSConfig loads TLS configuration
|
||||
func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
|
||||
if !c.TLSEnabled {
|
||||
|
||||
Reference in New Issue
Block a user