Files
drip/internal/server/tcp/proxy.go
Gouryella 85a0f44e44 feat: Add IP access control functionality
- Implement IP whitelist/blacklist access control mechanism
- Add --allow-ip and --deny-ip command-line arguments to configure IP access rules
- Support CIDR format for IP range configuration
- Enable IP access control in HTTP, HTTPS, and TCP tunnels
- Add IP access check logic to server-side proxy handling
- Update documentation to explain how to use IP access control
2026-01-11 14:22:41 +08:00

260 lines
4.7 KiB
Go

package tcp
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"drip/internal/shared/netutil"
"drip/internal/shared/pool"
"go.uber.org/zap"
)
// Proxy exposes a public TCP port and forwards each incoming
// connection over a dedicated mux stream.
//
// A Proxy is created with NewProxy, begins accepting with Start and is shut
// down with Stop. Stop is guarded by a sync.Once, so calling it repeatedly is
// safe.
type Proxy struct {
// port is the TCP port Start listens on (bound on 0.0.0.0).
port int
// subdomain identifies the tunnel in log output; it is only used as a log field here.
subdomain string
// logger receives lifecycle and per-connection diagnostic messages.
logger *zap.Logger
// listener is the public listener; it is set by Start and closed by Stop.
listener net.Listener
// stopCh is closed exactly once by Stop to tell the accept loop and
// in-flight handlers to give up.
stopCh chan struct{}
// once makes Stop idempotent.
once sync.Once
// wg tracks the accept loop and every handleConn goroutine so Stop can wait for them.
wg sync.WaitGroup
// openStream opens a new stream for one accepted connection. When it is
// nil, accepted connections are closed without being forwarded.
openStream func() (net.Conn, error)
// stats is optional traffic accounting; nil disables it.
stats trafficStats
// sem is a counting semaphore bounding concurrently handled connections.
// Handlers acquire it without blocking and drop the connection when it is
// full; nil disables the limit.
sem chan struct{}
// ctx is derived from the context given to NewProxy and is canceled by
// Stop (via cancel). It parents the open-stream timeout and is passed to
// the pipe helper.
ctx context.Context
// cancel cancels ctx.
cancel context.CancelFunc
// checkIPAccess is an optional access-control hook set through
// SetIPAccessCheck. It receives the client IP and returns false to reject
// the connection; nil allows every client.
checkIPAccess func(ip string) bool
}
// trafficStats is the accounting sink a Proxy reports to. It is declared
// here, next to its only consumer in this file, so any type with these four
// methods can be passed to NewProxy.
type trafficStats interface {
// AddBytesIn is invoked from the first byte-count callback handed to the
// pipe helper in handleConn.
// NOTE(review): which transfer direction counts as "in" is decided by
// netutil.PipeWithCallbacksAndBufferSize — confirm there.
AddBytesIn(n int64)
// AddBytesOut is invoked from the second byte-count callback handed to the
// pipe helper in handleConn (direction: see the note on AddBytesIn).
AddBytesOut(n int64)
// IncActiveConnections is called once a connection has passed the access
// check and acquired a concurrency slot.
IncActiveConnections()
// DecActiveConnections is deferred right after IncActiveConnections, so it
// runs when that connection's handler returns.
DecActiveConnections()
}
// NewProxy builds a Proxy for the given public port. Nothing is bound until
// Start is called.
//
// ctx is the parent of the proxy's internal context; a nil ctx is treated as
// context.Background(). openStream is invoked once per accepted connection to
// obtain the stream the connection is forwarded over. stats may be nil to
// disable traffic accounting. logger is used for all proxy log output.
func NewProxy(ctx context.Context, port int, subdomain string, openStream func() (net.Conn, error), stats trafficStats, logger *zap.Logger) *Proxy {
	// Upper bound on connections handled at the same time; connections
	// arriving beyond it are dropped by the handler.
	const maxConcurrentConnections = 10000

	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	proxyCtx, stop := context.WithCancel(parent)

	proxy := &Proxy{
		port:       port,
		subdomain:  subdomain,
		logger:     logger,
		stopCh:     make(chan struct{}),
		openStream: openStream,
		stats:      stats,
		ctx:        proxyCtx,
		cancel:     stop,
	}

	// A non-positive limit leaves sem nil, which the handler reads as
	// "no limit".
	if maxConcurrentConnections > 0 {
		proxy.sem = make(chan struct{}, maxConcurrentConnections)
	}

	return proxy
}
// SetIPAccessCheck sets the IP access control check function.
//
// check receives the client's IP address and returns true to allow the
// connection; passing nil disables access control. The field is written here
// and read by connection handlers without any synchronization, so call this
// before Start.
func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) {
p.checkIPAccess = check
}
// Start binds the proxy's port on all IPv4 interfaces and launches the accept
// loop in a background goroutine. It returns an error, wrapping the
// underlying cause, when the port cannot be bound; in that case no goroutine
// is started.
func (p *Proxy) Start() error {
	listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", p.port))
	if err != nil {
		return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
	}
	p.listener = listener

	p.logger.Info("TCP proxy started", zap.Int("port", p.port), zap.String("subdomain", p.subdomain))

	// Register the accept loop before it starts so Stop can wait for it.
	p.wg.Add(1)
	go p.acceptLoop()

	return nil
}
// Stop shuts the proxy down: it signals all goroutines through stopCh,
// cancels the proxy context, closes the listener and then waits up to 30
// seconds for the accept loop and in-flight connection handlers to finish.
// Only the first call does any work; later calls return immediately.
func (p *Proxy) Stop() {
	const stopTimeout = 30 * time.Second

	p.once.Do(func() {
		close(p.stopCh)
		p.cancel()
		// listener is nil when Start was never called or failed.
		if p.listener != nil {
			_ = p.listener.Close()
		}

		// Turn wg.Wait into a channel so it can be raced against a timer.
		drained := make(chan struct{})
		go func() {
			defer close(drained)
			p.wg.Wait()
		}()

		deadline := time.NewTimer(stopTimeout)
		defer deadline.Stop()

		select {
		case <-drained:
			p.logger.Info("TCP proxy stopped", zap.Int("port", p.port), zap.String("subdomain", p.subdomain))
		case <-deadline.C:
			p.logger.Warn("TCP proxy stop timed out",
				zap.Int("port", p.port),
				zap.String("subdomain", p.subdomain),
				zap.Duration("timeout", stopTimeout),
			)
		}
	})
}
// acceptLoop accepts connections on p.listener and hands each one to its own
// handleConn goroutine until the proxy is stopped or the listener is closed.
//
// For TCP listeners a short accept deadline is set on every iteration so the
// loop wakes up regularly and notices a closed stopCh even when no client
// connects.
//
// Accept errors are handled as follows:
//   - timeouts (the polling deadline) are retried immediately;
//   - net.ErrClosed ends the loop, because a closed listener can never
//     accept again;
//   - anything else (for example running out of file descriptors) is retried
//     with capped exponential backoff so a persistent failure cannot spin a
//     CPU core. The wait is cut short when the proxy is stopped.
func (p *Proxy) acceptLoop() {
	defer p.wg.Done()

	const (
		acceptPollInterval = 1 * time.Second
		minAcceptBackoff   = 5 * time.Millisecond
		maxAcceptBackoff   = 1 * time.Second
	)

	tcpLn, _ := p.listener.(*net.TCPListener)

	var backoff time.Duration
	for {
		select {
		case <-p.stopCh:
			return
		default:
		}

		if tcpLn != nil {
			// Best effort: if the deadline cannot be set, Accept simply
			// blocks until a client connects or the listener is closed.
			_ = tcpLn.SetDeadline(time.Now().Add(acceptPollInterval))
		}

		conn, err := p.listener.Accept()
		if err != nil {
			var netErr net.Error
			if errors.As(err, &netErr) && netErr.Timeout() {
				continue
			}
			if errors.Is(err, net.ErrClosed) {
				return
			}

			if backoff == 0 {
				backoff = minAcceptBackoff
			} else if backoff *= 2; backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}

			retry := time.NewTimer(backoff)
			select {
			case <-p.stopCh:
				retry.Stop()
				return
			case <-retry.C:
			}
			continue
		}
		backoff = 0

		// The loop itself still holds a wg count here, so this Add cannot
		// race with Stop's Wait observing a zero counter.
		p.wg.Add(1)
		go p.handleConn(conn)
	}
}
// handleConn serves one accepted client connection and always closes it
// before returning. In order, it:
//
//  1. rejects the client when an IP access check is installed and denies it;
//  2. takes a concurrency slot without blocking, dropping the connection
//     when the limit is reached;
//  3. updates the active-connection counter, if stats are configured;
//  4. applies best-effort TCP socket tuning;
//  5. opens a stream via openStreamWithTimeout and relays traffic between
//     the client and that stream until the pipe helper returns.
//
// It must be started with p.wg already incremented; it calls p.wg.Done.
func (p *Proxy) handleConn(conn net.Conn) {
	defer p.wg.Done()
	defer conn.Close()

	if p.checkIPAccess != nil {
		clientIP := netutil.ExtractIP(conn.RemoteAddr().String())
		if !p.checkIPAccess(clientIP) {
			p.logger.Debug("IP access denied",
				zap.String("ip", clientIP),
				zap.Int("port", p.port),
			)
			return
		}
	}

	if p.sem != nil {
		select {
		case p.sem <- struct{}{}:
			defer func() { <-p.sem }()
		default:
			// Shed load instead of queueing; log it so dropped clients
			// are diagnosable.
			p.logger.Debug("Connection limit reached, dropping connection",
				zap.Int("port", p.port),
			)
			return
		}
	}

	if p.stats != nil {
		p.stats.IncActiveConnections()
		defer p.stats.DecActiveConnections()
	}

	if tcpConn, ok := conn.(*net.TCPConn); ok {
		// Socket tuning is best effort: a failure here only affects
		// performance, so the errors are deliberately ignored.
		_ = tcpConn.SetNoDelay(true)
		_ = tcpConn.SetKeepAlive(true)
		_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
		_ = tcpConn.SetReadBuffer(256 * 1024)
		_ = tcpConn.SetWriteBuffer(256 * 1024)
	}

	stream := p.openStreamWithTimeout()
	if stream == nil {
		return
	}
	defer stream.Close()

	// The pipe helper's result is ignored: when it returns, this handler is
	// finished either way and the deferred Close calls release both ends.
	// NOTE(review): presumably the helper relays bytes in both directions
	// until one side closes or p.ctx is canceled — confirm in netutil.
	_ = netutil.PipeWithCallbacksAndBufferSize(
		p.ctx,
		conn,
		stream,
		pool.SizeLarge,
		func(n int64) {
			if p.stats != nil {
				p.stats.AddBytesIn(n)
			}
		},
		func(n int64) {
			if p.stats != nil {
				p.stats.AddBytesOut(n)
			}
		},
	)
}

// openStreamWithTimeout calls p.openStream in a helper goroutine and waits
// for the result for at most three seconds, or until the proxy is stopped.
// It returns the opened stream, or nil when no stream is available (no
// opener configured, open failed, timed out, or the proxy is shutting down).
// A stream that arrives after this method has given up is closed by the
// helper goroutine rather than leaked.
func (p *Proxy) openStreamWithTimeout() net.Conn {
	if p.openStream == nil {
		return nil
	}

	const openStreamTimeout = 3 * time.Second

	type streamResult struct {
		stream net.Conn
		err    error
	}

	ctx, cancel := context.WithTimeout(p.ctx, openStreamTimeout)
	defer cancel()

	// Unbuffered on purpose. With a buffered channel the send below would
	// always be ready, and once ctx is done select would pick at random
	// between parking the stream in a buffer nobody reads (a leak) and
	// closing it. Unbuffered, the send can only succeed while this method is
	// still receiving.
	resultCh := make(chan streamResult)
	go func() {
		s, err := p.openStream()
		select {
		case resultCh <- streamResult{stream: s, err: err}:
		case <-ctx.Done():
			if s != nil {
				_ = s.Close()
			}
		}
	}()

	select {
	case result := <-resultCh:
		if result.err != nil {
			// net.ErrClosed is expected while the tunnel is going away,
			// so it is not worth a log line.
			if !errors.Is(result.err, net.ErrClosed) {
				p.logger.Debug("Open stream failed", zap.Error(result.err))
			}
			return nil
		}
		if result.stream == nil {
			// Guard against an opener that reports success without a
			// stream; using it would panic.
			p.logger.Debug("Open stream returned no stream")
			return nil
		}
		return result.stream
	case <-ctx.Done():
		// ctx also ends when the proxy context is canceled; only a real
		// deadline expiry is reported as a timeout.
		if errors.Is(ctx.Err(), context.DeadlineExceeded) {
			p.logger.Debug("Open stream timeout")
		}
		return nil
	case <-p.stopCh:
		return nil
	}
}