feat(server): Add server configuration validation and optimize connection handling

- Add Validate method to ServerConfig to validate port ranges, domain format, TCP port ranges, and other configuration items
- Add configuration validation logic in server.go to ensure valid configuration before server startup
- Improve channel naming in TCP connections for better code readability
- Enhance data copying mechanism with context cancellation support to avoid resource leaks
- Add private network definitions for secure validation of trusted proxy headers

fix(proxy): Strengthen client IP extraction security and fix error handling

- Trust X-Forwarded-For and X-Real-IP headers only when requests originate from private/loopback networks
- Define RFC 1918 and other private network ranges for proxy header validation
- Add JSON serialization error handling in TCP connections to prevent data loss
- Fix context handling logic in pipe callbacks
- Optimize error handling mechanism for data connection responses

refactor(config): Improve client configuration validation and error handling

- Add Validate method to ClientConfig to verify server address format and port validity
- Change configuration validation from simple checks to full validation function calls
- Provide more detailed error messages to help users correctly configure server address formats
This commit is contained in:
Gouryella
2026-01-12 10:55:27 +08:00
parent 85a0f44e44
commit d7b92a8b95
7 changed files with 197 additions and 70 deletions

View File

@@ -21,17 +21,17 @@ import (
)
var (
serverPort int
serverPublicPort int
serverDomain string
serverAuthToken string
serverPort int
serverPublicPort int
serverDomain string
serverAuthToken string
serverMetricsToken string
serverDebug bool
serverTCPPortMin int
serverTCPPortMax int
serverTLSCert string
serverTLSKey string
serverPprofPort int
serverDebug bool
serverTCPPortMin int
serverTCPPortMax int
serverTLSCert string
serverTLSKey string
serverPprofPort int
)
var serverCmd = &cobra.Command{
@@ -113,6 +113,10 @@ func runServer(_ *cobra.Command, _ []string) error {
Debug: serverDebug,
}
if err := serverConfig.Validate(); err != nil {
logger.Fatal("Invalid server configuration", zap.Error(err))
}
tlsConfig, err := serverConfig.LoadTLSConfig()
if err != nil {
logger.Fatal("Failed to load TLS configuration", zap.Error(err))

View File

@@ -127,12 +127,12 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
return
}
done := make(chan struct{})
copyDone := make(chan struct{})
go func() {
select {
case <-ctx.Done():
stream.Close()
case <-done:
case <-copyDone:
}
}()
@@ -150,7 +150,7 @@ func (c *PoolClient) handleHTTPStream(stream net.Conn) {
break
}
}
close(done)
close(copyDone)
}
func (c *PoolClient) handleWebSocketUpgrade(cc net.Conn, req *http.Request) {

View File

@@ -41,6 +41,24 @@ type Handler struct {
metricsToken string
}
var privateNetworks []*net.IPNet
func init() {
privateCIDRs := []string{
"127.0.0.0/8", // IPv4 loopback
"10.0.0.0/8", // RFC 1918 Class A
"172.16.0.0/12", // RFC 1918 Class B
"192.168.0.0/16", // RFC 1918 Class C
"::1/128", // IPv6 loopback
"fc00::/7", // IPv6 unique local
"fe80::/10", // IPv6 link-local
}
for _, cidr := range privateCIDRs {
_, ipNet, _ := net.ParseCIDR(cidr)
privateNetworks = append(privateNetworks, ipNet)
}
}
func NewHandler(manager *tunnel.Manager, logger *zap.Logger, domain string, authToken string, metricsToken string) *Handler {
return &Handler{
manager: manager,
@@ -167,23 +185,23 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(statusCode)
// Use pooled buffer for zero-copy optimization
buf := pool.GetBuffer(pool.SizeLarge)
defer pool.PutBuffer(buf)
// Copy with context cancellation support
ctx := r.Context()
done := make(chan struct{})
copyDone := make(chan struct{})
go func() {
select {
case <-ctx.Done():
stream.Close()
case <-done:
case <-copyDone:
}
}()
// Use pooled buffer for zero-copy optimization
buf := pool.GetBuffer(pool.SizeLarge)
_, _ = io.CopyBuffer(w, resp.Body, (*buf)[:])
pool.PutBuffer(buf)
close(done)
stream.Close()
close(copyDone)
}
func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, error) {
@@ -192,24 +210,23 @@ func (h *Handler) openStreamWithTimeout(tconn *tunnel.Connection) (net.Conn, err
err error
}
ch := make(chan result, 1)
done := make(chan struct{})
defer close(done)
go func() {
s, err := tconn.OpenStream()
select {
case ch <- result{s, err}:
case <-done:
if s != nil {
s.Close()
}
}
ch <- result{s, err}
}()
select {
case r := <-ch:
return r.stream, r.err
case <-time.After(openStreamTimeout):
// Goroutine will eventually complete and send to buffered channel
// which will be garbage collected. If stream was opened, it needs cleanup.
go func() {
if r := <-ch; r.stream != nil {
r.stream.Close()
}
}()
return nil, fmt.Errorf("open stream timeout")
}
}
@@ -337,31 +354,58 @@ func (h *Handler) extractSubdomain(host string) string {
}
// extractClientIP extracts the client IP from the request.
// It checks X-Forwarded-For and X-Real-IP headers first (for reverse proxy setups),
// then falls back to the remote address.
// It only trusts X-Forwarded-For and X-Real-IP headers when the request
// comes from a private/loopback network (typical reverse proxy setup).
func (h *Handler) extractClientIP(r *http.Request) string {
// Check X-Forwarded-For header (may contain multiple IPs)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP (original client)
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// First, get the direct remote address
remoteIP := h.extractRemoteIP(r.RemoteAddr)
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
// Only trust proxy headers if the request comes from a private network
if isPrivateIP(remoteIP) {
// Check X-Forwarded-For header (may contain multiple IPs)
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP (original client)
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
}
// Fall back to remote address
host, _, err := net.SplitHostPort(r.RemoteAddr)
return remoteIP
}
// extractRemoteIP extracts the IP address from a remote address string (host:port format).
func (h *Handler) extractRemoteIP(remoteAddr string) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return r.RemoteAddr
return remoteAddr
}
return host
}
// isPrivateIP checks if the given IP is a private/loopback address.
func isPrivateIP(ip string) bool {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
for _, network := range privateNetworks {
if network.Contains(parsedIP) {
return true
}
}
return false
}
func (h *Handler) serveHomePage(w http.ResponseWriter, r *http.Request) {
html := `<!DOCTYPE html>
<html lang="en">

View File

@@ -235,7 +235,10 @@ func (c *Connection) Handle() error {
RecommendedConns: recommendedConns,
}
respData, _ := json.Marshal(resp)
respData, err := json.Marshal(resp)
if err != nil {
return fmt.Errorf("failed to marshal registration response: %w", err)
}
ackFrame := protocol.NewFrame(protocol.FrameTypeRegisterAck, respData)
err = protocol.WriteFrame(c.conn, ackFrame)
@@ -409,13 +412,6 @@ func (c *Connection) handleHTTPRequestLegacy(reader *bufio.Reader) error {
}
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func parseTCPSubdomainPort(subdomain string) (int, bool) {
if !strings.HasPrefix(subdomain, "tcp-") {
return 0, false
@@ -525,11 +521,15 @@ func (c *Connection) sendError(code, message string) {
Code: code,
Message: message,
}
data, _ := json.Marshal(errMsg)
data, err := json.Marshal(errMsg)
if err != nil {
c.logger.Error("Failed to marshal error message", zap.Error(err))
return
}
errFrame := protocol.NewFrame(protocol.FrameTypeError, data)
if c.frameWriter == nil {
protocol.WriteFrame(c.conn, errFrame)
_ = protocol.WriteFrame(c.conn, errFrame)
} else {
c.frameWriter.WriteFrame(errFrame)
}
@@ -676,7 +676,10 @@ func (c *Connection) handleDataConnect(frame *protocol.Frame, reader *bufio.Read
Message: "Data connection accepted",
}
respData, _ := json.Marshal(resp)
respData, err := json.Marshal(resp)
if err != nil {
return fmt.Errorf("failed to marshal data connect response: %w", err)
}
ackFrame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
if err := protocol.WriteFrame(c.conn, ackFrame); err != nil {
@@ -732,7 +735,11 @@ func (c *Connection) sendDataConnectError(code, message string) {
Accepted: false,
Message: fmt.Sprintf("%s: %s", code, message),
}
respData, _ := json.Marshal(resp)
respData, err := json.Marshal(resp)
if err != nil {
c.logger.Error("Failed to marshal data connect error", zap.Error(err))
return
}
frame := protocol.NewFrame(protocol.FrameTypeDataConnectAck, respData)
protocol.WriteFrame(c.conn, frame)
_ = protocol.WriteFrame(c.conn, frame)
}

View File

@@ -64,16 +64,6 @@ func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser
errCh := make(chan error, 2)
if ctx != nil {
go func() {
select {
case <-ctx.Done():
closeAll()
case <-stopCh:
}
}()
}
go func() {
defer wg.Done()
err := pipeBuffer(b, a, bufSize, onAToB, stopCh)
@@ -92,6 +82,16 @@ func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser
closeAll()
}()
if ctx != nil {
go func() {
select {
case <-ctx.Done():
closeAll()
case <-stopCh:
}
}()
}
wg.Wait()
select {

View File

@@ -2,8 +2,10 @@ package config
import (
"fmt"
"net"
"os"
"path/filepath"
"strings"
"gopkg.in/yaml.v3"
)
@@ -15,6 +17,31 @@ type ClientConfig struct {
TLS bool `yaml:"tls"` // Use TLS (always true for production)
}
// Validate checks if the client configuration is valid
func (c *ClientConfig) Validate() error {
if c.Server == "" {
return fmt.Errorf("server address is required")
}
host, port, err := net.SplitHostPort(c.Server)
if err != nil {
if strings.Contains(err.Error(), "missing port") {
return fmt.Errorf("server address must include port (e.g., example.com:443), got: %s", c.Server)
}
return fmt.Errorf("invalid server address format: %s (expected host:port)", c.Server)
}
if host == "" {
return fmt.Errorf("server host is required")
}
if port == "" {
return fmt.Errorf("server port is required")
}
return nil
}
// DefaultClientConfig returns the default configuration path
func DefaultClientConfigPath() string {
home, err := os.UserHomeDir()
@@ -43,8 +70,8 @@ func LoadClientConfig(path string) (*ClientConfig, error) {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
if config.Server == "" {
return nil, fmt.Errorf("server address is required in config")
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("invalid config: %w", err)
}
return &config, nil

View File

@@ -4,6 +4,7 @@ import (
"crypto/tls"
"fmt"
"os"
"strings"
)
// ServerConfig holds the server configuration
@@ -30,6 +31,50 @@ type ServerConfig struct {
Debug bool
}
// Validate checks if the server configuration is valid
func (c *ServerConfig) Validate() error {
// Validate port
if c.Port < 1 || c.Port > 65535 {
return fmt.Errorf("invalid port %d: must be between 1 and 65535", c.Port)
}
// Validate public port if set
if c.PublicPort != 0 && (c.PublicPort < 1 || c.PublicPort > 65535) {
return fmt.Errorf("invalid public port %d: must be between 1 and 65535", c.PublicPort)
}
// Validate domain
if c.Domain == "" {
return fmt.Errorf("domain is required")
}
if strings.Contains(c.Domain, ":") {
return fmt.Errorf("domain should not contain port, got: %s", c.Domain)
}
// Validate TCP port range
if c.TCPPortMin < 1 || c.TCPPortMin > 65535 {
return fmt.Errorf("invalid TCPPortMin %d: must be between 1 and 65535", c.TCPPortMin)
}
if c.TCPPortMax < 1 || c.TCPPortMax > 65535 {
return fmt.Errorf("invalid TCPPortMax %d: must be between 1 and 65535", c.TCPPortMax)
}
if c.TCPPortMin >= c.TCPPortMax {
return fmt.Errorf("TCPPortMin (%d) must be less than TCPPortMax (%d)", c.TCPPortMin, c.TCPPortMax)
}
// Validate TLS settings
if c.TLSEnabled {
if c.TLSCertFile == "" {
return fmt.Errorf("TLS certificate file is required when TLS is enabled")
}
if c.TLSKeyFile == "" {
return fmt.Errorf("TLS key file is required when TLS is enabled")
}
}
return nil
}
// LoadTLSConfig loads TLS configuration
func (c *ServerConfig) LoadTLSConfig() (*tls.Config, error) {
if !c.TLSEnabled {