Files
drip/internal/shared/netutil/pipe.go
Gouryella d7b92a8b95 feat(server): Add server configuration validation and optimize connection handling
- Add Validate method to ServerConfig to validate port ranges, domain format, TCP port ranges, and other configuration items
- Add configuration validation logic in server.go to ensure valid configuration before server startup
- Improve channel naming in TCP connections for better code readability
- Enhance data copying mechanism with context cancellation support to avoid resource leaks
- Add private network definitions for secure validation of trusted proxy headers

fix(proxy): Strengthen client IP extraction security and fix error handling

- Trust X-Forwarded-For and X-Real-IP headers only when requests originate from private/loopback networks
- Define RFC 1918 and other private network ranges for proxy header validation
- Add JSON serialization error handling in TCP connections to prevent data loss
- Fix context handling logic in pipe callbacks
- Optimize error handling mechanism for data connection responses

refactor(config): Improve client configuration validation and error handling

- Add Validate method to ClientConfig to verify server address format and port validity
- Change configuration validation from simple checks to full validation function calls
- Provide more detailed error messages to help users correctly configure server address formats
2026-01-12 10:55:27 +08:00

165 lines
3.4 KiB
Go

package netutil
import (
"context"
"io"
"sync"
"time"
"drip/internal/shared/pool"
)
// tcpWaitTimeout is the grace period granted after one copy direction
// finishes: it is installed as a read deadline on that direction's
// destination, bounding how long the opposite direction may keep reading
// from it.
const tcpWaitTimeout = 10 * time.Second

// Optional connection capabilities. They are discovered with type assertions
// so that plain io.ReadWriteCloser values still work; *net.TCPConn provides
// all three.
type (
	// closeReader can shut down only its receiving side.
	closeReader interface {
		CloseRead() error
	}

	// closeWriter can shut down only its sending side (TCP half-close).
	closeWriter interface {
		CloseWrite() error
	}

	// readDeadliner can bound future reads with an absolute deadline.
	readDeadliner interface {
		SetReadDeadline(t time.Time) error
	}
)
// Pipe copies bytes in both directions between a and b (gost-like).
// It uses a pool.SizeMedium buffer per direction, keeps no byte counters, and
// applies TCP half-close when the connections support it. Cancelling ctx
// closes both ends; a and b are closed by the time Pipe returns.
func Pipe(ctx context.Context, a, b io.ReadWriteCloser) error {
	var noCount func(n int64) // byte accounting disabled in both directions
	return PipeWithCallbacksAndBufferSize(ctx, a, b, pool.SizeMedium, noCount, noCount)
}
// PipeWithCallbacks is Pipe with optional byte counters for each direction:
// onAToB is called with bytes copied from a -> b, onBToA for b -> a.
func PipeWithCallbacks(ctx context.Context, a, b io.ReadWriteCloser, onAToB func(n int64), onBToA func(n int64)) error {
return PipeWithCallbacksAndBufferSize(ctx, a, b, pool.SizeMedium, onAToB, onBToA)
}
// PipeWithBufferSize behaves like Pipe but copies through buffers of bufSize
// bytes per direction. Out-of-range sizes are normalised by
// PipeWithCallbacksAndBufferSize.
func PipeWithBufferSize(ctx context.Context, a, b io.ReadWriteCloser, bufSize int) error {
	var onAToB, onBToA func(n int64)
	return PipeWithCallbacksAndBufferSize(ctx, a, b, bufSize, onAToB, onBToA)
}
// PipeWithCallbacksAndBufferSize is PipeWithCallbacks with a custom buffer size.
//
// A bufSize <= 0 falls back to pool.SizeMedium, and values above
// pool.SizeLarge are clamped to pool.SizeLarge. Each direction is copied in
// its own goroutine, and the pipe shuts down as follows:
//
//   - If a direction fails, a and b are both closed at once so the other
//     direction unblocks, and that failure is returned. Once teardown has
//     begun, later errors (including those caused by the teardown's own
//     Close calls) are ignored.
//   - If a direction ends cleanly (its source hit EOF), pipeBuffer has
//     already shut down its destination for writing. The opposite direction
//     keeps running so data still in flight the other way is delivered
//     (half-close).
//   - pipeBuffer also arms a tcpWaitTimeout read deadline on that
//     destination when supported. If the opposite direction then fails,
//     including by hitting that deadline, its error is returned. When no
//     deadline can be set, nothing would bound the remaining read, so both
//     ends are closed immediately instead.
//   - If ctx is cancelled first, both ends are closed and ctx.Err() is
//     returned.
//
// a and b are always closed by the time the function returns. A nil ctx is
// accepted and disables cancellation.
func PipeWithCallbacksAndBufferSize(ctx context.Context, a, b io.ReadWriteCloser, bufSize int, onAToB func(n int64), onBToA func(n int64)) error {
	if bufSize <= 0 {
		bufSize = pool.SizeMedium
	}
	if bufSize > pool.SizeLarge {
		bufSize = pool.SizeLarge
	}

	// cause records why teardown started. It is written only inside
	// closeOnce.Do and read only after closeOnce.Do has returned, which
	// orders the accesses without an extra lock.
	var (
		wg        sync.WaitGroup
		closeOnce sync.Once
		cause     error
	)
	stopCh := make(chan struct{})

	// teardown closes both ends exactly once and keeps the first reason it
	// was given; later calls are no-ops.
	teardown := func(reason error) {
		closeOnce.Do(func() {
			cause = reason
			close(stopCh)
			_ = a.Close()
			_ = b.Close()
		})
	}

	// run copies src -> dst and decides whether the whole pipe must stop.
	run := func(dst, src io.ReadWriteCloser, onCopied func(n int64)) {
		defer wg.Done()
		if err := pipeBuffer(dst, src, bufSize, onCopied, stopCh); err != nil {
			teardown(err)
			return
		}
		// Clean EOF. Keep the reverse direction alive only if the read
		// deadline pipeBuffer arms on dst can bound it (same method set as
		// readDeadliner); otherwise that read could block forever.
		if _, ok := dst.(interface{ SetReadDeadline(time.Time) error }); !ok {
			teardown(nil)
		}
	}

	wg.Add(2)
	go run(b, a, onAToB)
	go run(a, b, onBToA)

	if ctx != nil {
		go func() {
			select {
			case <-ctx.Done():
				teardown(ctx.Err())
			case <-stopCh:
			}
		}()
	}

	wg.Wait()
	// Both directions are done: release whatever is still open and let the
	// ctx watcher goroutine exit.
	teardown(nil)
	return cause
}
// pipeBuffer copies src into dst through a pooled buffer of bufSize bytes,
// then shuts the pair down the way a finished direction should:
//
//   - the read side of src is closed when src supports CloseRead;
//   - dst gets CloseWrite when supported, falling back to a full Close when
//     CloseWrite is unsupported or fails;
//   - when dst accepts read deadlines, a deadline tcpWaitTimeout from now is
//     set on it, so reads on dst fail once that grace period has passed.
//
// The shutdown steps run regardless of the copy outcome and their errors are
// ignored. It returns the copy error, if any.
func pipeBuffer(dst io.ReadWriteCloser, src io.ReadWriteCloser, bufSize int, onCopied func(n int64), stopCh <-chan struct{}) error {
	bufPtr := pool.GetBuffer(bufSize)
	defer pool.PutBuffer(bufPtr)

	_, copyErr := copyBuffer(dst, src, (*bufPtr)[:bufSize], onCopied, stopCh)

	if r, canCloseRead := src.(closeReader); canCloseRead {
		_ = r.CloseRead()
	}
	// Short-circuit: CloseWrite is only attempted when dst supports it.
	if w, canCloseWrite := dst.(closeWriter); !canCloseWrite || w.CloseWrite() != nil {
		_ = dst.Close()
	}
	if d, canDeadline := dst.(readDeadliner); canDeadline {
		_ = d.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
	}
	return copyErr
}
// copyBuffer copies src into dst through buf until src is exhausted, an
// error occurs, or stopCh is closed.
//
// Like io.CopyBuffer, reaching io.EOF on src counts as success and yields a
// nil error. Any other read or write error is returned as-is, and a write
// that accepts fewer bytes than were read yields io.ErrShortWrite.
//
// stopCh is polled before every Read. Once it is closed the copy stops with
// a nil error: stopping is requested by whoever tears the pipe down, and that
// party reports the reason. A nil stopCh never fires. stopCh cannot interrupt
// a Read that is already blocked.
//
// onCopied, when non-nil, is called with the byte count of every write that
// made progress, including a partial write that also failed. written is the
// sum of those counts. A writer reporting a count outside [0, nr] is treated
// as having written nothing, so the accounting never exceeds what was read.
func copyBuffer(dst io.Writer, src io.Reader, buf []byte, onCopied func(n int64), stopCh <-chan struct{}) (written int64, err error) {
	for {
		select {
		case <-stopCh:
			// Not an error: see the stopCh note above.
			return written, nil
		default:
		}

		nr, er := src.Read(buf)
		if nr > 0 {
			nw, ew := dst.Write(buf[:nr])
			if nw < 0 || nw > nr {
				// io.Writer forbids such counts; don't let a misbehaving
				// writer inflate written or the onCopied totals.
				nw = 0
				if ew == nil {
					ew = io.ErrShortWrite
				}
			}
			if nw > 0 {
				written += int64(nw)
				if onCopied != nil {
					onCopied(int64(nw))
				}
			}
			if ew != nil {
				return written, ew
			}
			if nw != nr {
				return written, io.ErrShortWrite
			}
		}
		if er != nil {
			if er == io.EOF {
				return written, nil
			}
			return written, er
		}
	}
}