mirror of
https://github.com/Gouryella/drip.git
synced 2026-02-23 21:00:44 +00:00
- Implement IP whitelist/blacklist access control mechanism - Add --allow-ip and --deny-ip command-line arguments to configure IP access rules - Support CIDR format for IP range configuration - Enable IP access control in HTTP, HTTPS, and TCP tunnels - Add IP access check logic to server-side proxy handling - Update documentation to explain how to use IP access control
260 lines
4.7 KiB
Go
260 lines
4.7 KiB
Go
package tcp
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"drip/internal/shared/netutil"
|
|
"drip/internal/shared/pool"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Proxy exposes a public TCP port and forwards each incoming
|
|
// connection over a dedicated mux stream.
|
|
type Proxy struct {
|
|
port int
|
|
subdomain string
|
|
logger *zap.Logger
|
|
|
|
listener net.Listener
|
|
stopCh chan struct{}
|
|
once sync.Once
|
|
wg sync.WaitGroup
|
|
|
|
openStream func() (net.Conn, error)
|
|
stats trafficStats
|
|
sem chan struct{}
|
|
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
|
|
checkIPAccess func(ip string) bool
|
|
}
|
|
|
|
// trafficStats is the sink the proxy reports to: byte counts for each
// direction and a gauge of connections currently being handled.
type trafficStats interface {
	// AddBytesIn and AddBytesOut accumulate n transferred bytes.
	AddBytesIn(n int64)
	AddBytesOut(n int64)

	// IncActiveConnections and DecActiveConnections bracket the lifetime of
	// one handled connection.
	IncActiveConnections()
	DecActiveConnections()
}
|
|
|
|
func NewProxy(ctx context.Context, port int, subdomain string, openStream func() (net.Conn, error), stats trafficStats, logger *zap.Logger) *Proxy {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
cctx, cancel := context.WithCancel(ctx)
|
|
|
|
const maxConcurrentConnections = 10000
|
|
var sem chan struct{}
|
|
if maxConcurrentConnections > 0 {
|
|
sem = make(chan struct{}, maxConcurrentConnections)
|
|
}
|
|
|
|
return &Proxy{
|
|
port: port,
|
|
subdomain: subdomain,
|
|
logger: logger,
|
|
stopCh: make(chan struct{}),
|
|
openStream: openStream,
|
|
stats: stats,
|
|
sem: sem,
|
|
ctx: cctx,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
|
|
// SetIPAccessCheck sets the IP access control check function.
|
|
func (p *Proxy) SetIPAccessCheck(check func(ip string) bool) {
|
|
p.checkIPAccess = check
|
|
}
|
|
|
|
func (p *Proxy) Start() error {
|
|
addr := fmt.Sprintf("0.0.0.0:%d", p.port)
|
|
|
|
ln, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
|
|
}
|
|
p.listener = ln
|
|
|
|
p.logger.Info("TCP proxy started",
|
|
zap.Int("port", p.port),
|
|
zap.String("subdomain", p.subdomain),
|
|
)
|
|
|
|
p.wg.Add(1)
|
|
go p.acceptLoop()
|
|
return nil
|
|
}
|
|
|
|
func (p *Proxy) Stop() {
|
|
p.once.Do(func() {
|
|
close(p.stopCh)
|
|
p.cancel()
|
|
|
|
if p.listener != nil {
|
|
_ = p.listener.Close()
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
p.wg.Wait()
|
|
close(done)
|
|
}()
|
|
|
|
const stopTimeout = 30 * time.Second
|
|
|
|
select {
|
|
case <-done:
|
|
p.logger.Info("TCP proxy stopped",
|
|
zap.Int("port", p.port),
|
|
zap.String("subdomain", p.subdomain),
|
|
)
|
|
case <-time.After(stopTimeout):
|
|
p.logger.Warn("TCP proxy stop timed out",
|
|
zap.Int("port", p.port),
|
|
zap.String("subdomain", p.subdomain),
|
|
zap.Duration("timeout", stopTimeout),
|
|
)
|
|
}
|
|
})
|
|
}
|
|
|
|
func (p *Proxy) acceptLoop() {
|
|
defer p.wg.Done()
|
|
|
|
tcpLn, _ := p.listener.(*net.TCPListener)
|
|
|
|
for {
|
|
select {
|
|
case <-p.stopCh:
|
|
return
|
|
default:
|
|
}
|
|
|
|
if tcpLn != nil {
|
|
_ = tcpLn.SetDeadline(time.Now().Add(1 * time.Second))
|
|
}
|
|
|
|
conn, err := p.listener.Accept()
|
|
if err != nil {
|
|
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
|
continue
|
|
}
|
|
select {
|
|
case <-p.stopCh:
|
|
return
|
|
default:
|
|
continue
|
|
}
|
|
}
|
|
|
|
p.wg.Add(1)
|
|
go p.handleConn(conn)
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) handleConn(conn net.Conn) {
|
|
defer p.wg.Done()
|
|
defer conn.Close()
|
|
|
|
if p.checkIPAccess != nil {
|
|
clientIP := netutil.ExtractIP(conn.RemoteAddr().String())
|
|
if !p.checkIPAccess(clientIP) {
|
|
p.logger.Debug("IP access denied",
|
|
zap.String("ip", clientIP),
|
|
zap.Int("port", p.port),
|
|
)
|
|
return
|
|
}
|
|
}
|
|
|
|
if p.sem != nil {
|
|
select {
|
|
case p.sem <- struct{}{}:
|
|
defer func() { <-p.sem }()
|
|
default:
|
|
return
|
|
}
|
|
}
|
|
|
|
if p.stats != nil {
|
|
p.stats.IncActiveConnections()
|
|
defer p.stats.DecActiveConnections()
|
|
}
|
|
|
|
if tcpConn, ok := conn.(*net.TCPConn); ok {
|
|
_ = tcpConn.SetNoDelay(true)
|
|
_ = tcpConn.SetKeepAlive(true)
|
|
_ = tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
|
_ = tcpConn.SetReadBuffer(256 * 1024)
|
|
_ = tcpConn.SetWriteBuffer(256 * 1024)
|
|
}
|
|
|
|
if p.openStream == nil {
|
|
return
|
|
}
|
|
|
|
const openStreamTimeout = 3 * time.Second
|
|
type streamResult struct {
|
|
stream net.Conn
|
|
err error
|
|
}
|
|
resultCh := make(chan streamResult, 1)
|
|
|
|
ctx, cancel := context.WithTimeout(p.ctx, openStreamTimeout)
|
|
defer cancel()
|
|
|
|
go func() {
|
|
s, err := p.openStream()
|
|
select {
|
|
case resultCh <- streamResult{s, err}:
|
|
case <-ctx.Done():
|
|
if s != nil {
|
|
s.Close()
|
|
}
|
|
}
|
|
}()
|
|
|
|
var stream net.Conn
|
|
select {
|
|
case result := <-resultCh:
|
|
if result.err != nil {
|
|
if !errors.Is(result.err, net.ErrClosed) {
|
|
p.logger.Debug("Open stream failed", zap.Error(result.err))
|
|
}
|
|
return
|
|
}
|
|
stream = result.stream
|
|
case <-ctx.Done():
|
|
p.logger.Debug("Open stream timeout")
|
|
return
|
|
case <-p.stopCh:
|
|
return
|
|
}
|
|
|
|
defer stream.Close()
|
|
|
|
_ = netutil.PipeWithCallbacksAndBufferSize(
|
|
p.ctx,
|
|
conn,
|
|
stream,
|
|
pool.SizeLarge,
|
|
func(n int64) {
|
|
if p.stats != nil {
|
|
p.stats.AddBytesIn(n)
|
|
}
|
|
},
|
|
func(n int64) {
|
|
if p.stats != nil {
|
|
p.stats.AddBytesOut(n)
|
|
}
|
|
},
|
|
)
|
|
}
|