feat(tcp): add transmission protocol control functionality

This commit is contained in:
Gouryella
2026-01-16 14:50:18 +08:00
parent 761c170642
commit bb1ba6d4b2
4 changed files with 36 additions and 11 deletions

View File

@@ -21,7 +21,7 @@
}
# Reverse proxy to drip-server (plain TCP mode)
reverse_proxy drip-server:8443 {
reverse_proxy host.docker.internal:8443 {
# Pass original host header
header_up Host {host}
header_up X-Real-IP {remote_host}

View File

@@ -14,15 +14,22 @@ services:
DOMAIN: ${DOMAIN}
ACME_EMAIL: ${ACME_EMAIL:-}
CF_API_TOKEN: ${CF_API_TOKEN}
extra_hosts:
- "host.docker.internal:host-gateway"
mem_limit: 256m
mem_reservation: 64m
drip-server:
image: driptunnel/drip-server:${VERSION:-latest}
container_name: drip-server
restart: unless-stopped
ports:
- "20000-20100:20000-20100"
network_mode: host
volumes:
- ./config.yaml:/app/config.yaml:ro
environment:
GOMEMLIMIT: 256MiB
mem_limit: 512m
mem_reservation: 128m
volumes:
caddy-data:

View File

@@ -61,6 +61,7 @@ type Connection struct {
// Server capabilities
allowedTunnelTypes []string
allowedTransports []string
}
// NewConnection creates a new connection handler
@@ -113,6 +114,12 @@ func (c *Connection) Handle() error {
return c.handleHTTPRequest(reader)
}
// Check if TCP transport is allowed (only for Drip protocol connections, not HTTP)
if !c.isTransportAllowed("tcp") {
c.logger.Warn("TCP transport not allowed, rejecting Drip protocol connection")
return fmt.Errorf("TCP transport not allowed")
}
frame, err := protocol.ReadFrame(reader)
if err != nil {
return fmt.Errorf("failed to read registration frame: %w", err)
@@ -767,6 +774,24 @@ func (c *Connection) SetAllowedTunnelTypes(types []string) {
c.allowedTunnelTypes = types
}
// SetAllowedTransports sets the allowed transports for this connection
func (c *Connection) SetAllowedTransports(transports []string) {
c.allowedTransports = transports
}
// isTransportAllowed checks if a transport is allowed
func (c *Connection) isTransportAllowed(transport string) bool {
if len(c.allowedTransports) == 0 {
return true
}
for _, t := range c.allowedTransports {
if strings.EqualFold(t, transport) {
return true
}
}
return false
}
// isTunnelTypeAllowed checks if a tunnel type is allowed
func (c *Connection) isTunnelTypeAllowed(tunnelType string) bool {
if len(c.allowedTunnelTypes) == 0 {

View File

@@ -207,14 +207,6 @@ func (l *Listener) handleConnection(netConn net.Conn) {
l.connMu.Unlock()
})
// Check if TCP transport is allowed
if !l.IsTransportAllowed("tcp") {
l.logger.Warn("TCP transport not allowed, rejecting connection",
zap.String("remote_addr", netConn.RemoteAddr().String()),
)
return
}
// Handle TLS connections
if tlsConn, ok := netConn.(*tls.Conn); ok {
if err := tlsConn.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
@@ -279,6 +271,7 @@ func (l *Listener) handleConnection(netConn net.Conn) {
conn := NewConnection(netConn, l.authToken, l.manager, l.logger, l.portAlloc, l.domain, l.tunnelDomain, l.publicPort, l.httpHandler, l.groupManager, l.httpListener)
conn.SetAllowedTunnelTypes(l.allowedTunnelTypes)
conn.SetAllowedTransports(l.allowedTransports)
connID := netConn.RemoteAddr().String()
l.connMu.Lock()