mirror of
https://github.com/TrustTunnel/TrustTunnel.git
synced 2026-04-24 03:30:41 +00:00
Pull request 170: [Github PR] Fix reverse proxy routing for H2/H3
* commit '9ba3ef505404c6fcf6153e5f0640be81b25e2f5b': Add auth_failure_status_code to docs Update changelog and docs Add auth_failure_status_code feature to change response code on auth failure Disable H3 reverse proxy tests on Linux Support new fields in setup_wizard Remove redundant code Remove all changes related to token-gated tunnel routing docs: update reverse proxy paths Add coverage for routing and reverse proxy (H2/H3, chunked, custom paths) Stabilize reverse proxy for HTTP/2 and HTTP/3 (chunked, EOF, stream shutdown) Reverse proxy selection and routing Token gating and deny fallback
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
# CHANGELOG
|
||||
|
||||
- [Fix] Reverse proxy routing for H2/H3.
|
||||
- [Feature] Add `ping_enable`, `ping_path`, `speedtest_enable` and `speedtest_path` config keys to configure ping and speedtest handlers.
|
||||
- [Feature] Add `auth_failure_status_code` config key to control the HTTP status code returned on authentication failure (407 or 405). Defaults to 407.
|
||||
|
||||
## 1.0.16
|
||||
|
||||
- [Fix] HTTP/1.1 codec busy loop when receiving partial request headers.
|
||||
|
||||
@@ -247,18 +247,25 @@ action = "deny"
|
||||
|
||||
### Core Settings
|
||||
|
||||
| Setting | Type | Default | Description |
|
||||
| ------- | ---- | ------- | ----------- |
|
||||
| `listen_address` | String | `0.0.0.0:443` | Address and port to listen on |
|
||||
| `ipv6_available` | Boolean | `true` | Whether IPv6 connections can be routed |
|
||||
| `allow_private_network_connections` | Boolean | `false` | Allow connections to endpoint's private network |
|
||||
| `tls_handshake_timeout_secs` | Integer | `10` | TLS handshake timeout in seconds |
|
||||
| `client_listener_timeout_secs` | Integer | `600` | Client listener timeout in seconds (10 minutes) |
|
||||
| `connection_establishment_timeout_secs` | Integer | `30` | Outgoing connection timeout in seconds |
|
||||
| `tcp_connections_timeout_secs` | Integer | `604800` | Idle TCP connection timeout (1 week) |
|
||||
| `udp_connections_timeout_secs` | Integer | `300` | UDP connection timeout (5 minutes) |
|
||||
| `credentials_file` | String | - | Path to credentials file |
|
||||
| `rules_file` | String | - | Path to rules file (optional) |
|
||||
| Setting | Type | Default | Description |
|
||||
|-----------------------------------------|---------|---------------|------------------------------------------------------------------|
|
||||
| `listen_address` | String | `0.0.0.0:443` | Address and port to listen on |
|
||||
| `ipv6_available` | Boolean | `true` | Whether IPv6 connections can be routed |
|
||||
| `allow_private_network_connections` | Boolean | `false` | Allow connections to endpoint's private network |
|
||||
| `tls_handshake_timeout_secs` | Integer | `10` | TLS handshake timeout in seconds |
|
||||
| `client_listener_timeout_secs` | Integer | `600` | Client listener timeout in seconds (10 minutes) |
|
||||
| `connection_establishment_timeout_secs` | Integer | `30` | Outgoing connection timeout in seconds |
|
||||
| `tcp_connections_timeout_secs` | Integer | `604800` | Idle TCP connection timeout (1 week) |
|
||||
| `udp_connections_timeout_secs` | Integer | `300` | UDP connection timeout (5 minutes) |
|
||||
| `credentials_file` | String | - | Path to credentials file |
|
||||
| `rules_file` | String | - | Path to rules file (optional) |
|
||||
| `speedtest_enable` | Boolean | `false` | Enable speedtest handler on main hosts |
|
||||
| `ping_enable` | Boolean | `false` | Enable ping handler on main hosts |
|
||||
| `ping_path` | String | - | Optional path prefix for ping on main hosts |
|
||||
| `speedtest_path` | String | - | Optional path prefix for speedtest on main hosts |
|
||||
| `auth_failure_status_code` | Integer | `407` | HTTP status code returned on authentication failure (405 or 407) |
|
||||
|
||||
Ping and speedtest are matched only via their configured paths. Default paths are: `/ping` and `/speedtest`.
|
||||
|
||||
### Listen Protocol Settings
|
||||
|
||||
|
||||
@@ -47,8 +47,9 @@ Client's connection is treated as a reverse proxy stream in the following cases:
|
||||
|
||||
1) A TLS session or QUIC connection has the SNI set to the host name equal to one
|
||||
from `TlsHostsSettings.reverse_proxy`.
|
||||
2) An HTTP/1.1 request has `Upgrade` header and its path starts with `ReverseProxySettings.path_mask`.
|
||||
3) An HTTP/3 request has a path starting with `ReverseProxySettings.path_mask`.
|
||||
2) If a request path starts with `ReverseProxySettings.path_mask`, it is routed to reverse proxy.
|
||||
3) Otherwise, routing is defined by `ping_path` and `speedtest_path` configuration.
|
||||
Requests that do not match ping, speedtest, or reverse proxy rules are treated as tunnel requests.
|
||||
|
||||
The stream is used for mutual client and endpoint notifications and some control messages.
|
||||
The endpoint does TLS termination on such connections and translates HTTP/x traffic into
|
||||
@@ -58,7 +59,10 @@ Like this:
|
||||
```(client) TLS(HTTP/x) <--(endpoint)--> (server) HTTP/1.1```
|
||||
|
||||
The translated HTTP/1.1 requests have the custom header `X-Original-Protocol` appended.
|
||||
For now, its value can be either `HTTP1`, or `HTTP3`.
|
||||
For now, its value can be `HTTP1`, `HTTP2`, or `HTTP3`.
|
||||
|
||||
Note: HTTP/3 reverse proxy handling keeps the write side open when the client finishes sending
|
||||
the request body, to avoid truncating large responses.
|
||||
|
||||
### Authentication
|
||||
|
||||
|
||||
@@ -539,6 +539,7 @@ impl Core {
|
||||
}
|
||||
},
|
||||
context.settings.tls_handshake_timeout,
|
||||
context.settings.speedtest_path.clone(),
|
||||
client_id,
|
||||
)
|
||||
.await
|
||||
@@ -631,6 +632,7 @@ impl Core {
|
||||
context.shutdown.clone(),
|
||||
Box::new(Http3Codec::new(socket, client_id.clone())),
|
||||
context.settings.tls_handshake_timeout,
|
||||
context.settings.speedtest_path.clone(),
|
||||
client_id,
|
||||
)
|
||||
.await
|
||||
|
||||
@@ -145,7 +145,8 @@ impl Http3Codec {
|
||||
Ok(None)
|
||||
}
|
||||
QuicSocketEvent::Close(stream_id) => {
|
||||
let _ = self.on_stream_shutdown(stream_id, None);
|
||||
// Client finished sending request body; keep write side open for response.
|
||||
let _ = self.on_stream_shutdown(stream_id, Some(quiche::Shutdown::Read));
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{http_codec, http_speedtest_handler, net_utils, settings, tls_demultiplexer};
|
||||
use crate::{http_codec, net_utils, settings, tls_demultiplexer};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub(crate) struct HttpDemux {
|
||||
@@ -12,65 +12,41 @@ impl HttpDemux {
|
||||
|
||||
pub fn select(
|
||||
&self,
|
||||
protocol: tls_demultiplexer::Protocol,
|
||||
_protocol: tls_demultiplexer::Protocol,
|
||||
request: &http_codec::RequestHeaders,
|
||||
) -> net_utils::Channel {
|
||||
if self.check_ping(request) {
|
||||
net_utils::Channel::Ping
|
||||
} else if self.check_speedtest(request) {
|
||||
net_utils::Channel::Speedtest
|
||||
} else if self.check_reverse_proxy(protocol, request) {
|
||||
net_utils::Channel::ReverseProxy
|
||||
} else {
|
||||
net_utils::Channel::Tunnel
|
||||
match () {
|
||||
_ if self.check_ping(request) => net_utils::Channel::Ping,
|
||||
_ if self.check_speedtest(request) => net_utils::Channel::Speedtest,
|
||||
_ if self.check_reverse_proxy_path(request) => net_utils::Channel::ReverseProxy,
|
||||
_ => net_utils::Channel::Tunnel,
|
||||
}
|
||||
}
|
||||
|
||||
fn check_ping(&self, request: &http_codec::RequestHeaders) -> bool {
|
||||
static MARKER_HEADERS: [(http::HeaderName, http::HeaderValue); 2] = [
|
||||
(
|
||||
http::HeaderName::from_static("x-ping"),
|
||||
http::HeaderValue::from_static("1"),
|
||||
),
|
||||
(
|
||||
http::HeaderName::from_static("sec-fetch-mode"),
|
||||
http::HeaderValue::from_static("navigate"),
|
||||
),
|
||||
];
|
||||
|
||||
MARKER_HEADERS
|
||||
.iter()
|
||||
.any(|(name, value)| request.headers.get(name) == Some(value))
|
||||
if !self.core_settings.ping_enable {
|
||||
return false;
|
||||
}
|
||||
if let Some(path) = self.core_settings.ping_path.as_ref() {
|
||||
return request.uri.path().starts_with(path);
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn check_speedtest(&self, request: &http_codec::RequestHeaders) -> bool {
|
||||
if !self.core_settings.speedtest_enable {
|
||||
return false;
|
||||
}
|
||||
request
|
||||
.uri
|
||||
.path()
|
||||
.strip_prefix('/')
|
||||
.and_then(|x| x.strip_prefix(http_speedtest_handler::SKIPPABLE_PATH_SEGMENT))
|
||||
.and_then(|x| x.strip_prefix('/'))
|
||||
.is_some()
|
||||
if let Some(path) = self.core_settings.speedtest_path.as_ref() {
|
||||
return request.uri.path().starts_with(path);
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
fn check_reverse_proxy(
|
||||
&self,
|
||||
protocol: tls_demultiplexer::Protocol,
|
||||
request: &http_codec::RequestHeaders,
|
||||
) -> bool {
|
||||
match protocol {
|
||||
tls_demultiplexer::Protocol::Http1 => {
|
||||
if !request.headers.contains_key(http::header::UPGRADE) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
tls_demultiplexer::Protocol::Http3 => (),
|
||||
_ => return false,
|
||||
fn check_reverse_proxy_path(&self, request: &http_codec::RequestHeaders) -> bool {
|
||||
if self.core_settings.reverse_proxy.is_none() {
|
||||
return false;
|
||||
}
|
||||
|
||||
self.core_settings
|
||||
.reverse_proxy
|
||||
.as_ref()
|
||||
|
||||
@@ -22,7 +22,6 @@ const HEALTH_CHECK_AUTHORITY: &str = "_check";
|
||||
const UDP_AUTHORITY: &str = "_udp2";
|
||||
const ICMP_AUTHORITY: &str = "_icmp";
|
||||
|
||||
const AUTHORIZATION_FAILURE_STATUS_CODE: StatusCode = StatusCode::PROXY_AUTHENTICATION_REQUIRED;
|
||||
const AUTHORIZATION_FAILURE_EXTRA_HEADER: (&str, &str) =
|
||||
("proxy-authenticate", "Basic realm=Authorization Required");
|
||||
|
||||
@@ -40,11 +39,13 @@ pub(crate) struct HttpDownstream {
|
||||
struct TcpConnection {
|
||||
stream: Box<dyn http_codec::Stream>,
|
||||
id: log_utils::IdChain<u64>,
|
||||
auth_failure_status_code: StatusCode,
|
||||
}
|
||||
|
||||
struct DatagramMultiplexer {
|
||||
stream: Box<dyn http_codec::Stream>,
|
||||
id: log_utils::IdChain<u64>,
|
||||
auth_failure_status_code: StatusCode,
|
||||
}
|
||||
|
||||
struct DatagramEncoder<D> {
|
||||
@@ -61,6 +62,7 @@ struct DatagramDecoder<D> {
|
||||
struct PendingRequest {
|
||||
stream: Box<dyn http_codec::Stream>,
|
||||
id: log_utils::IdChain<u64>,
|
||||
auth_failure_status_code: StatusCode,
|
||||
}
|
||||
|
||||
impl HttpDownstream {
|
||||
@@ -112,9 +114,13 @@ impl Downstream for HttpDownstream {
|
||||
match channel {
|
||||
net_utils::Channel::Tunnel => {
|
||||
log_id!(trace, stream_id, "HTTP downstream: tunnel request");
|
||||
break Ok(Some(Box::new(PendingRequest {
|
||||
let auth_failure_status_code =
|
||||
StatusCode::from_u16(self.context.settings.auth_failure_status_code)
|
||||
.unwrap_or(StatusCode::PROXY_AUTHENTICATION_REQUIRED);
|
||||
return Ok(Some(Box::new(PendingRequest {
|
||||
stream,
|
||||
id: stream_id,
|
||||
auth_failure_status_code,
|
||||
})));
|
||||
}
|
||||
net_utils::Channel::Ping => {
|
||||
@@ -136,6 +142,7 @@ impl Downstream for HttpDownstream {
|
||||
context.shutdown.clone(),
|
||||
Box::new(http_codec::stream_into_codec(stream, protocol)),
|
||||
context.settings.tls_handshake_timeout,
|
||||
context.settings.speedtest_path.clone(),
|
||||
stream_id,
|
||||
)
|
||||
.await
|
||||
@@ -201,7 +208,7 @@ impl downstream::PendingRequest for TcpConnection {
|
||||
}
|
||||
|
||||
fn fail_request(self: Box<Self>, error: tunnel::ConnectionError) {
|
||||
fail_request_with_error(self.stream, error);
|
||||
fail_request_with_error(self.stream, error, self.auth_failure_status_code);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -261,6 +268,7 @@ impl downstream::PendingRequest for PendingRequest {
|
||||
DatagramMultiplexer {
|
||||
stream: self.stream,
|
||||
id: self.id,
|
||||
auth_failure_status_code: self.auth_failure_status_code,
|
||||
},
|
||||
)),
|
||||
))
|
||||
@@ -274,13 +282,14 @@ impl downstream::PendingRequest for PendingRequest {
|
||||
Box::new(TcpConnection {
|
||||
stream: self.stream,
|
||||
id: self.id,
|
||||
auth_failure_status_code: self.auth_failure_status_code,
|
||||
}),
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn fail_request(self: Box<Self>, error: tunnel::ConnectionError) {
|
||||
fail_request_with_error(self.stream, error);
|
||||
fail_request_with_error(self.stream, error, self.auth_failure_status_code);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,7 +333,7 @@ impl downstream::PendingRequest for DatagramMultiplexer {
|
||||
}
|
||||
|
||||
fn fail_request(self: Box<Self>, error: tunnel::ConnectionError) {
|
||||
fail_request_with_error(self.stream, error);
|
||||
fail_request_with_error(self.stream, error, self.auth_failure_status_code);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -394,9 +403,12 @@ impl<D: Send> datagram_pipe::Sink for DatagramEncoder<D> {
|
||||
}
|
||||
}
|
||||
|
||||
fn tunnel_error_to_status_code(error: &tunnel::ConnectionError) -> StatusCode {
|
||||
fn tunnel_error_to_status_code(
|
||||
error: &tunnel::ConnectionError,
|
||||
auth_failure_status_code: StatusCode,
|
||||
) -> StatusCode {
|
||||
match error {
|
||||
tunnel::ConnectionError::Authentication(_) => AUTHORIZATION_FAILURE_STATUS_CODE,
|
||||
tunnel::ConnectionError::Authentication(_) => auth_failure_status_code,
|
||||
_ => BAD_STATUS_CODE,
|
||||
}
|
||||
}
|
||||
@@ -404,16 +416,23 @@ fn tunnel_error_to_status_code(error: &tunnel::ConnectionError) -> StatusCode {
|
||||
fn tunnel_error_to_warn_header(
|
||||
error: &tunnel::ConnectionError,
|
||||
hostname: &str,
|
||||
auth_failure_status_code: StatusCode,
|
||||
) -> Vec<(String, String)> {
|
||||
match error {
|
||||
tunnel::ConnectionError::Io(_) => vec![(
|
||||
WARNING_HEADER_NAME.to_string(),
|
||||
"300 - Connection failed for some reason".to_string(),
|
||||
)],
|
||||
tunnel::ConnectionError::Authentication(_) => vec![(
|
||||
AUTHORIZATION_FAILURE_EXTRA_HEADER.0.to_string(),
|
||||
AUTHORIZATION_FAILURE_EXTRA_HEADER.1.to_string(),
|
||||
)],
|
||||
tunnel::ConnectionError::Authentication(_) => {
|
||||
if auth_failure_status_code == StatusCode::PROXY_AUTHENTICATION_REQUIRED {
|
||||
vec![(
|
||||
AUTHORIZATION_FAILURE_EXTRA_HEADER.0.to_string(),
|
||||
AUTHORIZATION_FAILURE_EXTRA_HEADER.1.to_string(),
|
||||
)]
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
tunnel::ConnectionError::Timeout => {
|
||||
vec![(WARNING_HEADER_NAME.to_string(), format!("302 - {}", error))]
|
||||
}
|
||||
@@ -446,9 +465,21 @@ fn fail_request(
|
||||
}
|
||||
}
|
||||
|
||||
fn fail_request_with_error(stream: Box<dyn http_codec::Stream>, error: tunnel::ConnectionError) {
|
||||
let extra_headers = tunnel_error_to_warn_header(&error, request_hostname(stream.request()));
|
||||
fail_request(stream, tunnel_error_to_status_code(&error), extra_headers);
|
||||
fn fail_request_with_error(
|
||||
stream: Box<dyn http_codec::Stream>,
|
||||
error: tunnel::ConnectionError,
|
||||
auth_failure_status_code: StatusCode,
|
||||
) {
|
||||
let extra_headers = tunnel_error_to_warn_header(
|
||||
&error,
|
||||
request_hostname(stream.request()),
|
||||
auth_failure_status_code,
|
||||
);
|
||||
fail_request(
|
||||
stream,
|
||||
tunnel_error_to_status_code(&error, auth_failure_status_code),
|
||||
extra_headers,
|
||||
);
|
||||
}
|
||||
|
||||
fn request_hostname(request: &dyn http_codec::PendingRequest) -> &str {
|
||||
@@ -457,3 +488,54 @@ fn request_hostname(request: &dyn http_codec::PendingRequest) -> &str {
|
||||
.map(http::uri::Authority::as_str)
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn tunnel_error_to_status_code_default_407() {
|
||||
let status_code = StatusCode::PROXY_AUTHENTICATION_REQUIRED;
|
||||
let error = tunnel::ConnectionError::Authentication("bad creds".into());
|
||||
assert_eq!(
|
||||
tunnel_error_to_status_code(&error, status_code),
|
||||
StatusCode::PROXY_AUTHENTICATION_REQUIRED
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tunnel_error_to_status_code_configured_405() {
|
||||
let status_code = StatusCode::METHOD_NOT_ALLOWED;
|
||||
let error = tunnel::ConnectionError::Authentication("bad creds".into());
|
||||
assert_eq!(
|
||||
tunnel_error_to_status_code(&error, status_code),
|
||||
StatusCode::METHOD_NOT_ALLOWED
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tunnel_error_to_status_code_non_auth_error_unaffected() {
|
||||
let error = tunnel::ConnectionError::Timeout;
|
||||
assert_eq!(
|
||||
tunnel_error_to_status_code(&error, StatusCode::METHOD_NOT_ALLOWED),
|
||||
BAD_STATUS_CODE
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn warn_header_includes_proxy_authenticate_for_407() {
|
||||
let status_code = StatusCode::PROXY_AUTHENTICATION_REQUIRED;
|
||||
let error = tunnel::ConnectionError::Authentication("bad creds".into());
|
||||
let headers = tunnel_error_to_warn_header(&error, "example.com", status_code);
|
||||
assert_eq!(headers.len(), 1);
|
||||
assert_eq!(headers[0].0, "proxy-authenticate");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn warn_header_empty_for_405() {
|
||||
let status_code = StatusCode::METHOD_NOT_ALLOWED;
|
||||
let error = tunnel::ConnectionError::Authentication("bad creds".into());
|
||||
let headers = tunnel_error_to_warn_header(&error, "example.com", status_code);
|
||||
assert!(headers.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,13 +2,12 @@ use crate::http_codec::HttpCodec;
|
||||
use crate::shutdown::Shutdown;
|
||||
use crate::{http_codec, log_id, log_utils, pipe};
|
||||
use bytes::Bytes;
|
||||
use std::borrow::Cow;
|
||||
use std::io::ErrorKind;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
|
||||
pub(crate) const SKIPPABLE_PATH_SEGMENT: &str = "speed";
|
||||
|
||||
const MAX_DOWNLOAD_MB: u32 = 100;
|
||||
const MAX_UPLOAD_MB: u32 = 120;
|
||||
const CHUNK_SIZE: usize = 64 * 1024;
|
||||
@@ -27,6 +26,7 @@ pub(crate) async fn listen(
|
||||
shutdown: Arc<Mutex<Shutdown>>,
|
||||
mut codec: Box<dyn HttpCodec>,
|
||||
timeout: Duration,
|
||||
base_path: Option<String>,
|
||||
log_id: log_utils::IdChain<u64>,
|
||||
) {
|
||||
let (mut shutdown_notification, _shutdown_completion) = {
|
||||
@@ -41,7 +41,7 @@ pub(crate) async fn listen(
|
||||
Err(e) => log_id!(debug, log_id, "Shutdown notification failure: {}", e),
|
||||
}
|
||||
},
|
||||
_ = listen_inner(codec.as_mut(), timeout, &log_id) => (),
|
||||
_ = listen_inner(codec.as_mut(), timeout, base_path.as_deref(), &log_id) => (),
|
||||
}
|
||||
|
||||
if let Err(e) = codec.graceful_shutdown().await {
|
||||
@@ -52,6 +52,7 @@ pub(crate) async fn listen(
|
||||
async fn listen_inner(
|
||||
codec: &mut dyn HttpCodec,
|
||||
timeout: Duration,
|
||||
base_path: Option<&str>,
|
||||
log_id: &log_utils::IdChain<u64>,
|
||||
) {
|
||||
let manager = Arc::new(SpeedtestManager::default());
|
||||
@@ -60,7 +61,7 @@ async fn listen_inner(
|
||||
Ok(Ok(Some(x))) => {
|
||||
let request_headers = x.request().request();
|
||||
log_id!(trace, x.id(), "Received request: {:?}", request_headers);
|
||||
match prepare_speedtest(request_headers) {
|
||||
match prepare_speedtest(request_headers, base_path) {
|
||||
Ok(Speedtest::Download(n)) => {
|
||||
manager.running_tests_num.fetch_add(1, Ordering::AcqRel);
|
||||
tokio::spawn({
|
||||
@@ -122,16 +123,23 @@ async fn listen_inner(
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_speedtest(request: &http_codec::RequestHeaders) -> Result<Speedtest, String> {
|
||||
let path = if let Some(x) = request
|
||||
.uri
|
||||
.path()
|
||||
.strip_prefix('/')
|
||||
.and_then(|x| x.strip_prefix(SKIPPABLE_PATH_SEGMENT))
|
||||
{
|
||||
x
|
||||
fn prepare_speedtest(
|
||||
request: &http_codec::RequestHeaders,
|
||||
base_path: Option<&str>,
|
||||
) -> Result<Speedtest, String> {
|
||||
let original_path = request.uri.path();
|
||||
let path = if let Some(base_path) = base_path {
|
||||
if let Some(x) = original_path.strip_prefix(base_path) {
|
||||
if x.starts_with('/') {
|
||||
Cow::Borrowed(x)
|
||||
} else {
|
||||
Cow::Owned(format!("/{x}"))
|
||||
}
|
||||
} else {
|
||||
Cow::Borrowed(original_path)
|
||||
}
|
||||
} else {
|
||||
request.uri.path()
|
||||
Cow::Borrowed(original_path)
|
||||
};
|
||||
|
||||
match request.method {
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::pipe::DuplexPipe;
|
||||
use crate::tcp_forwarder::TcpForwarder;
|
||||
use crate::tls_demultiplexer::Protocol;
|
||||
use crate::{core, forwarder, http1_codec, http_codec, log_id, log_utils, pipe, tunnel};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use std::net::Ipv4Addr;
|
||||
@@ -14,6 +14,7 @@ use std::sync::Arc;
|
||||
|
||||
static ORIGINAL_PROTOCOL_HEADER: http::HeaderName =
|
||||
http::HeaderName::from_static("x-original-protocol");
|
||||
const H3_BUFFERED_BODY_LIMIT: usize = 2 * 1024 * 1024;
|
||||
|
||||
#[derive(Default)]
|
||||
struct SessionManager {
|
||||
@@ -133,7 +134,9 @@ async fn handle_stream(
|
||||
let original_version = request_headers.version;
|
||||
match protocol {
|
||||
Protocol::Http1 => (),
|
||||
Protocol::Http2 => unreachable!(),
|
||||
Protocol::Http2 => {
|
||||
request_headers.version = http::Version::HTTP_11;
|
||||
}
|
||||
Protocol::Http3 => {
|
||||
request_headers.version = http::Version::HTTP_11;
|
||||
if settings.h3_backward_compatibility
|
||||
@@ -159,13 +162,17 @@ async fn handle_stream(
|
||||
server_sink.write_all(encoded).await?;
|
||||
|
||||
let mut buffer = BytesMut::new();
|
||||
let (response, chunk) = loop {
|
||||
let (response, chunk, is_chunked) = loop {
|
||||
match server_source.read().await? {
|
||||
pipe::Data::Chunk(chunk) => {
|
||||
server_source.consume(chunk.len())?;
|
||||
buffer.put(chunk);
|
||||
}
|
||||
pipe::Data::Eof => return Err(ErrorKind::UnexpectedEof.into()),
|
||||
pipe::Data::Eof => {
|
||||
// Upstream closed before sending a valid HTTP response. Reply with 502
|
||||
// to avoid surfacing this as an H2 stream cancel to the client.
|
||||
return send_bad_gateway(respond, original_version);
|
||||
}
|
||||
}
|
||||
|
||||
match http1_codec::decode_response(
|
||||
@@ -176,16 +183,225 @@ async fn handle_stream(
|
||||
http1_codec::DecodeStatus::Partial(b) => buffer = b,
|
||||
http1_codec::DecodeStatus::Complete(mut h, tail) => {
|
||||
h.version = original_version; // restore the version in case it was not the same
|
||||
break (h, tail.freeze());
|
||||
let transfer_encoding_raw = h
|
||||
.headers
|
||||
.get(http::header::TRANSFER_ENCODING)
|
||||
.and_then(|x| x.to_str().ok())
|
||||
.map(str::to_owned);
|
||||
let is_chunked = transfer_encoding_raw
|
||||
.as_deref()
|
||||
.is_some_and(|v| v.to_ascii_lowercase().contains("chunked"));
|
||||
if !matches!(protocol, Protocol::Http1) {
|
||||
// Strip hop-by-hop headers that are invalid in HTTP/2 and HTTP/3.
|
||||
h.headers.remove(http::header::CONNECTION);
|
||||
h.headers.remove(http::header::TRANSFER_ENCODING);
|
||||
h.headers.remove(http::header::UPGRADE);
|
||||
h.headers.remove(http::header::TE);
|
||||
h.headers.remove(http::header::TRAILER);
|
||||
h.headers
|
||||
.remove(http::HeaderName::from_static("keep-alive"));
|
||||
h.headers
|
||||
.remove(http::HeaderName::from_static("proxy-connection"));
|
||||
}
|
||||
break (h, tail.freeze(), is_chunked);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut client_sink = respond.send_response(response, false)?.into_pipe_sink();
|
||||
let content_length = response
|
||||
.headers
|
||||
.get(http::header::CONTENT_LENGTH)
|
||||
.and_then(|x| x.to_str().ok())
|
||||
.and_then(|x| x.parse::<usize>().ok());
|
||||
// H3 streaming is fragile in practice; buffer reasonably small bodies to avoid truncation.
|
||||
if matches!(protocol, Protocol::Http3)
|
||||
&& content_length.is_some_and(|x| x <= H3_BUFFERED_BODY_LIMIT)
|
||||
{
|
||||
let total = content_length.unwrap();
|
||||
let mut body = BytesMut::with_capacity(total);
|
||||
let chunk_len = chunk.len();
|
||||
body.put(chunk);
|
||||
server_source.consume(chunk_len)?;
|
||||
|
||||
let mut remaining = total.saturating_sub(chunk_len);
|
||||
while remaining > 0 {
|
||||
match server_source.read().await? {
|
||||
pipe::Data::Chunk(chunk) => {
|
||||
server_source.consume(chunk.len())?;
|
||||
let to_take = std::cmp::min(chunk.len(), remaining);
|
||||
body.put(chunk.slice(..to_take));
|
||||
remaining -= to_take;
|
||||
}
|
||||
pipe::Data::Eof => {
|
||||
return Err(ErrorKind::UnexpectedEof.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut client_sink = respond.send_response(response, false)?.into_pipe_sink();
|
||||
write_all(&mut client_sink, body.freeze()).await?;
|
||||
client_sink.eof()?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let chunk_len = chunk.len();
|
||||
client_sink.write_all(chunk).await?;
|
||||
|
||||
if is_chunked && !matches!(protocol, Protocol::Http1) {
|
||||
async fn need_more(
|
||||
buffer: &mut BytesMut,
|
||||
server_source: &mut Box<dyn pipe::Source>,
|
||||
_log_id: &log_utils::IdChain<u64>,
|
||||
) -> io::Result<bool> {
|
||||
match server_source.read().await? {
|
||||
pipe::Data::Chunk(chunk) => {
|
||||
server_source.consume(chunk.len())?;
|
||||
buffer.put(chunk);
|
||||
Ok(false)
|
||||
}
|
||||
pipe::Data::Eof => Ok(true),
|
||||
}
|
||||
}
|
||||
|
||||
// Decode chunked HTTP/1 body and stream raw bytes to the client.
|
||||
// For HTTP/2 and HTTP/3, buffer reasonably small responses so we can
|
||||
// set Content-Length and avoid relying on connection teardown signals.
|
||||
let buffer_for_length = !matches!(protocol, Protocol::Http1);
|
||||
let mut buffered_body = BytesMut::new();
|
||||
let mut respond_opt = Some(respond);
|
||||
let mut response_opt = Some(response);
|
||||
let mut client_sink: Option<Box<dyn pipe::Sink>> = None;
|
||||
let ensure_client_sink = |response_opt: &mut Option<http_codec::ResponseHeaders>,
|
||||
respond_opt: &mut Option<Box<dyn http_codec::PendingRespond>>,
|
||||
client_sink: &mut Option<Box<dyn pipe::Sink>>|
|
||||
-> io::Result<()> {
|
||||
if client_sink.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
let response = response_opt
|
||||
.take()
|
||||
.ok_or_else(|| io::Error::new(ErrorKind::Other, "missing response"))?;
|
||||
let respond = respond_opt
|
||||
.take()
|
||||
.ok_or_else(|| io::Error::new(ErrorKind::Other, "missing respond"))?;
|
||||
*client_sink = Some(respond.send_response(response, false)?.into_pipe_sink());
|
||||
Ok(())
|
||||
};
|
||||
|
||||
let mut buffer = BytesMut::new();
|
||||
buffer.put(chunk);
|
||||
server_source.consume(chunk_len)?;
|
||||
loop {
|
||||
// Ensure we have a full chunk size line.
|
||||
let line_end = loop {
|
||||
if let Some(pos) = buffer.windows(2).position(|w| w == b"\r\n").map(|p| p + 2) {
|
||||
break pos;
|
||||
}
|
||||
let eof = need_more(&mut buffer, &mut server_source, log_id).await?;
|
||||
if eof {
|
||||
return Err(ErrorKind::UnexpectedEof.into());
|
||||
}
|
||||
};
|
||||
|
||||
let mut line = buffer.split_to(line_end);
|
||||
line.truncate(line.len().saturating_sub(2));
|
||||
let line = std::str::from_utf8(&line)
|
||||
.map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
|
||||
let size_hex = line.split(';').next().unwrap_or_default().trim();
|
||||
let chunk_size = usize::from_str_radix(size_hex, 16)
|
||||
.map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?;
|
||||
if chunk_size == 0 {
|
||||
// If there are no trailers, the stream ends with a single CRLF.
|
||||
if buffer.len() >= 2 && &buffer[..2] == b"\r\n" {
|
||||
buffer.advance(2);
|
||||
} else {
|
||||
// Consume trailers until CRLFCRLF or upstream EOF.
|
||||
loop {
|
||||
if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||
break;
|
||||
}
|
||||
let eof = need_more(&mut buffer, &mut server_source, log_id).await?;
|
||||
if eof {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if buffer_for_length && client_sink.is_none() {
|
||||
if let Some(resp) = response_opt.as_mut() {
|
||||
resp.headers
|
||||
.insert(http::header::CONTENT_LENGTH, buffered_body.len().into());
|
||||
}
|
||||
ensure_client_sink(&mut response_opt, &mut respond_opt, &mut client_sink)?;
|
||||
if let Some(sink) = client_sink.as_mut() {
|
||||
write_all(sink, buffered_body.split().freeze()).await?;
|
||||
}
|
||||
}
|
||||
ensure_client_sink(&mut response_opt, &mut respond_opt, &mut client_sink)?;
|
||||
if let Some(sink) = client_sink.as_mut() {
|
||||
sink.eof()?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Ensure we have the full chunk plus its trailing CRLF.
|
||||
while buffer.len() < chunk_size + 2 {
|
||||
let eof = need_more(&mut buffer, &mut server_source, log_id).await?;
|
||||
if eof {
|
||||
return Err(ErrorKind::UnexpectedEof.into());
|
||||
}
|
||||
}
|
||||
|
||||
let data = buffer.split_to(chunk_size).freeze();
|
||||
let _ = buffer.split_to(2);
|
||||
|
||||
if buffer_for_length && client_sink.is_none() {
|
||||
buffered_body.put(data.clone());
|
||||
if buffered_body.len() > H3_BUFFERED_BODY_LIMIT {
|
||||
// Fall back to streaming without a known length.
|
||||
ensure_client_sink(&mut response_opt, &mut respond_opt, &mut client_sink)?;
|
||||
if let Some(sink) = client_sink.as_mut() {
|
||||
write_all(sink, buffered_body.split().freeze()).await?;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
ensure_client_sink(&mut response_opt, &mut respond_opt, &mut client_sink)?;
|
||||
if let Some(sink) = client_sink.as_mut() {
|
||||
write_all(sink, data).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut client_sink = respond.send_response(response, false)?.into_pipe_sink();
|
||||
write_all(&mut client_sink, chunk).await?;
|
||||
server_source.consume(chunk_len)?;
|
||||
|
||||
if let Some(mut remaining) = content_length.and_then(|x| x.checked_sub(chunk_len)) {
|
||||
log_id!(
|
||||
debug,
|
||||
log_id,
|
||||
"Reverse proxy fixed-size body: remaining={} bytes after initial send",
|
||||
remaining
|
||||
);
|
||||
while remaining > 0 {
|
||||
match server_source.read().await? {
|
||||
pipe::Data::Chunk(chunk) => {
|
||||
server_source.consume(chunk.len())?;
|
||||
let to_send = std::cmp::min(chunk.len(), remaining);
|
||||
write_all(&mut client_sink, chunk.slice(..to_send)).await?;
|
||||
remaining -= to_send;
|
||||
}
|
||||
pipe::Data::Eof => break,
|
||||
}
|
||||
}
|
||||
if let Err(e) = client_sink.eof() {
|
||||
log_id!(debug, log_id, "Failed to close client stream: {}", e);
|
||||
} else {
|
||||
log_id!(debug, log_id, "Reverse proxy client stream closed");
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut pipe = DuplexPipe::new(
|
||||
(
|
||||
pipe::SimplexDirection::Outgoing,
|
||||
@@ -196,6 +412,38 @@ async fn handle_stream(
|
||||
|_, _| (),
|
||||
);
|
||||
|
||||
pipe.exchange(context.settings.tcp_connections_timeout)
|
||||
match pipe
|
||||
.exchange(context.settings.tcp_connections_timeout)
|
||||
.await
|
||||
{
|
||||
Ok(()) => Ok(()),
|
||||
// HTTP/2 (and sometimes HTTP/3) can surface graceful stream closure
|
||||
// as UnexpectedEof once the response has already been delivered.
|
||||
Err(e) if e.kind() == ErrorKind::UnexpectedEof => Ok(()),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
async fn write_all(sink: &mut Box<dyn pipe::Sink>, mut data: bytes::Bytes) -> io::Result<()> {
|
||||
while !data.is_empty() {
|
||||
let before = data.len();
|
||||
data = sink.write(data)?;
|
||||
if data.len() == before || !data.is_empty() {
|
||||
sink.wait_writable().await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn send_bad_gateway(
|
||||
respond: Box<dyn http_codec::PendingRespond>,
|
||||
version: http::Version,
|
||||
) -> io::Result<()> {
|
||||
let response = http::Response::builder()
|
||||
.status(http::StatusCode::BAD_GATEWAY)
|
||||
.version(version)
|
||||
.body(())
|
||||
.map_err(|e| io::Error::new(ErrorKind::Other, format!("bad gateway: {}", e)))?;
|
||||
let (parts, _) = response.into_parts();
|
||||
respond.send_response(parts, true).map(|_| ())
|
||||
}
|
||||
|
||||
@@ -28,10 +28,14 @@ pub enum ValidationError {
|
||||
ReverseProxy(String),
|
||||
/// Invalid [`Settings.listen_protocols`]
|
||||
ListenProtocols(String),
|
||||
/// Invalid request path configuration
|
||||
InvalidPath(String),
|
||||
/// Invalid rules file
|
||||
RulesFile(String),
|
||||
/// No credentials configured while listening on a public address
|
||||
NoCredentialsOnPublicAddress,
|
||||
/// Invalid auth failure status code
|
||||
InvalidAuthFailureStatusCode(u16),
|
||||
}
|
||||
|
||||
impl Debug for ValidationError {
|
||||
@@ -43,12 +47,18 @@ impl Debug for ValidationError {
|
||||
Self::SpeedTlsHostInfo(x) => write!(f, "Invalid speedtest TLS hosts: {}", x),
|
||||
Self::ReverseProxy(x) => write!(f, "Invalid reverse proxy settings: {}", x),
|
||||
Self::ListenProtocols(x) => write!(f, "Invalid listen protocols settings: {}", x),
|
||||
Self::InvalidPath(x) => write!(f, "Invalid request path: {}", x),
|
||||
Self::RulesFile(x) => write!(f, "Invalid rules file: {}", x),
|
||||
Self::NoCredentialsOnPublicAddress => write!(
|
||||
f,
|
||||
"No credentials configured (credentials_file is missing) while listening on a public address. \
|
||||
This is a security risk. Either configure credentials or use a loopback address (127.0.0.1 or ::1)"
|
||||
),
|
||||
Self::InvalidAuthFailureStatusCode(code) => write!(
|
||||
f,
|
||||
"Invalid auth_failure_status_code: {}. Supported values: 407, 405",
|
||||
code
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -175,9 +185,23 @@ pub struct Settings {
|
||||
#[serde(deserialize_with = "deserialize_rules")]
|
||||
pub(crate) rules_engine: Option<rules::RulesEngine>,
|
||||
|
||||
/// Whether speedtest is available on the main hosts via `/speed` path.
|
||||
/// Whether speedtest is available on the main hosts.
|
||||
#[serde(default = "Settings::default_speedtest_enable")]
|
||||
pub(crate) speedtest_enable: bool,
|
||||
/// Whether ping is available on the main hosts.
|
||||
#[serde(default = "Settings::default_ping_enable")]
|
||||
pub(crate) ping_enable: bool,
|
||||
/// Optional path prefix for ping requests on main hosts.
|
||||
#[serde(default = "Settings::default_ping_path")]
|
||||
pub(crate) ping_path: Option<String>,
|
||||
/// Optional path prefix for speedtest requests on main hosts.
|
||||
#[serde(default = "Settings::default_speedtest_path")]
|
||||
pub(crate) speedtest_path: Option<String>,
|
||||
|
||||
/// HTTP status code returned on authentication failure.
|
||||
/// Supported values: 407 (Proxy Authentication Required) or 405 (Method Not Allowed).
|
||||
#[serde(default = "Settings::default_auth_failure_status_code")]
|
||||
pub(crate) auth_failure_status_code: u16,
|
||||
|
||||
/// Default maximum number of simultaneous HTTP/1 and HTTP/2 connections per client credentials.
|
||||
/// TrustTunnel clients open 8 HTTP/2 connections by default, so set this to
|
||||
@@ -500,6 +524,10 @@ impl Settings {
|
||||
.map(ReverseProxySettings::validate)
|
||||
.transpose()?;
|
||||
|
||||
Self::validate_request_path("ping_path", &self.ping_path)?;
|
||||
Self::validate_request_path("speedtest_path", &self.speedtest_path)?;
|
||||
Self::validate_request_path_overlaps(&self.ping_path, &self.speedtest_path)?;
|
||||
|
||||
if self.listen_protocols.http1.is_none()
|
||||
&& self.listen_protocols.http2.is_none()
|
||||
&& self.listen_protocols.quic.is_none()
|
||||
@@ -512,6 +540,12 @@ impl Settings {
|
||||
return Err(ValidationError::NoCredentialsOnPublicAddress);
|
||||
}
|
||||
|
||||
if self.auth_failure_status_code != 407 && self.auth_failure_status_code != 405 {
|
||||
return Err(ValidationError::InvalidAuthFailureStatusCode(
|
||||
self.auth_failure_status_code,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -550,6 +584,46 @@ impl Settings {
|
||||
pub fn default_speedtest_enable() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
pub fn default_speedtest_path() -> Option<String> {
|
||||
Some("/speedtest".to_string())
|
||||
}
|
||||
|
||||
pub fn default_ping_enable() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
pub fn default_ping_path() -> Option<String> {
|
||||
Some("/ping".to_string())
|
||||
}
|
||||
|
||||
pub fn default_auth_failure_status_code() -> u16 {
|
||||
407
|
||||
}
|
||||
|
||||
fn validate_request_path(name: &str, path: &Option<String>) -> Result<(), ValidationError> {
|
||||
if let Some(path) = path {
|
||||
if path.is_empty() || !path.starts_with('/') {
|
||||
return Err(ValidationError::InvalidPath(format!("{name}: {path}")));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_request_path_overlaps(
|
||||
left: &Option<String>,
|
||||
right: &Option<String>,
|
||||
) -> Result<(), ValidationError> {
|
||||
let (Some(left), Some(right)) = (left.as_ref(), right.as_ref()) else {
|
||||
return Ok(());
|
||||
};
|
||||
if left == right || left.starts_with(right) || right.starts_with(left) {
|
||||
return Err(ValidationError::InvalidPath(format!(
|
||||
"path overlap: {left} vs {right}"
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -576,8 +650,12 @@ impl Default for Settings {
|
||||
metrics: Default::default(),
|
||||
rules_engine: Some(rules::RulesEngine::default_allow()),
|
||||
speedtest_enable: false,
|
||||
ping_enable: Settings::default_ping_enable(),
|
||||
ping_path: None,
|
||||
speedtest_path: None,
|
||||
default_max_http2_conns_per_client: None,
|
||||
default_max_http3_conns_per_client: None,
|
||||
auth_failure_status_code: Settings::default_auth_failure_status_code(),
|
||||
built: false,
|
||||
}
|
||||
}
|
||||
@@ -832,8 +910,12 @@ impl SettingsBuilder {
|
||||
metrics: Default::default(),
|
||||
rules_engine: Some(rules::RulesEngine::default_allow()),
|
||||
speedtest_enable: Settings::default_speedtest_enable(),
|
||||
ping_enable: Settings::default_ping_enable(),
|
||||
ping_path: Settings::default_ping_path(),
|
||||
speedtest_path: Settings::default_speedtest_path(),
|
||||
default_max_http2_conns_per_client: None,
|
||||
default_max_http3_conns_per_client: None,
|
||||
auth_failure_status_code: Settings::default_auth_failure_status_code(),
|
||||
built: true,
|
||||
},
|
||||
}
|
||||
@@ -966,6 +1048,30 @@ impl SettingsBuilder {
|
||||
self.settings.default_max_http3_conns_per_client = x;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the HTTP status code for authentication failures (407 or 405)
|
||||
pub fn auth_failure_status_code(mut self, x: u16) -> Self {
|
||||
self.settings.auth_failure_status_code = x;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set whether ping is available
|
||||
pub fn ping_enable(mut self, x: bool) -> Self {
|
||||
self.settings.ping_enable = x;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set path prefix for ping requests on main hosts
|
||||
pub fn ping_path<S: Into<String>>(mut self, path: S) -> Self {
|
||||
self.settings.ping_path = Some(path.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set path prefix for speedtest requests on main hosts
|
||||
pub fn speedtest_path<S: Into<String>>(mut self, path: S) -> Self {
|
||||
self.settings.speedtest_path = Some(path.into());
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl TlsSettingsBuilder {
|
||||
@@ -1563,3 +1669,42 @@ where
|
||||
fn demangle_toml_string(x: String) -> String {
|
||||
x.replace('"', "").trim().to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_auth_failure_status_code_is_407() {
|
||||
let settings = Settings::default();
|
||||
assert_eq!(settings.auth_failure_status_code, 407);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_failure_status_code_407_valid() {
|
||||
let mut settings = Settings::default();
|
||||
settings.auth_failure_status_code = 407;
|
||||
settings.listen_address = (Ipv4Addr::LOCALHOST, 8443).into();
|
||||
assert!(settings.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_failure_status_code_405_valid() {
|
||||
let mut settings = Settings::default();
|
||||
settings.auth_failure_status_code = 405;
|
||||
settings.listen_address = (Ipv4Addr::LOCALHOST, 8443).into();
|
||||
assert!(settings.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_failure_status_code_200_invalid() {
|
||||
let mut settings = Settings::default();
|
||||
settings.auth_failure_status_code = 200;
|
||||
settings.listen_address = (Ipv4Addr::LOCALHOST, 8443).into();
|
||||
let err = settings.validate().unwrap_err();
|
||||
assert!(matches!(
|
||||
err,
|
||||
ValidationError::InvalidAuthFailureStatusCode(200)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,6 +112,10 @@ impl Tunnel {
|
||||
log_id!(trace, self.id, "Tunnel received request");
|
||||
r
|
||||
}
|
||||
Ok(Err(e)) if e.kind() == ErrorKind::UnexpectedEof => {
|
||||
log_id!(debug, self.id, "Tunnel closed gracefully");
|
||||
return Ok(());
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
log_id!(trace, self.id, "Tunnel listen error: {}", e);
|
||||
return Err(e);
|
||||
|
||||
@@ -239,6 +239,9 @@ pub async fn run_endpoint(listen_address: &SocketAddr) {
|
||||
})
|
||||
.allow_private_network_connections(true)
|
||||
.speedtest_enable(true)
|
||||
.ping_enable(true)
|
||||
.ping_path("/ping")
|
||||
.speedtest_path("/speed")
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -34,12 +34,9 @@ ping_tests! {
|
||||
sni_h1: sni_h1_client,
|
||||
sni_h2: sni_h2_client,
|
||||
sni_h3: sni_h3_client,
|
||||
x_ping_h1: x_ping_h1_client,
|
||||
x_ping_h2: x_ping_h2_client,
|
||||
x_ping_h3: x_ping_h3_client,
|
||||
navigate_h1: navigate_h1_client,
|
||||
navigate_h2: navigate_h2_client,
|
||||
navigate_h3: navigate_h3_client,
|
||||
path_h1: path_h1_client,
|
||||
path_h2: path_h2_client,
|
||||
path_h3: path_h3_client,
|
||||
}
|
||||
|
||||
async fn sni_h1_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
@@ -110,7 +107,7 @@ async fn sni_h3_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
conn.recv_response().await.status
|
||||
}
|
||||
|
||||
async fn x_ping_h1_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
async fn path_h1_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
let stream =
|
||||
common::establish_tls_connection(common::MAIN_DOMAIN_NAME, endpoint_address, None).await;
|
||||
|
||||
@@ -118,18 +115,18 @@ async fn x_ping_h1_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
stream,
|
||||
http::Version::HTTP_11,
|
||||
&format!(
|
||||
"https://{}:{}",
|
||||
"https://{}:{}/ping",
|
||||
common::MAIN_DOMAIN_NAME,
|
||||
endpoint_address.port()
|
||||
),
|
||||
&[("x-ping", "1")],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.0
|
||||
.status
|
||||
}
|
||||
|
||||
async fn x_ping_h2_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
async fn path_h2_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
let stream = common::establish_tls_connection(
|
||||
common::MAIN_DOMAIN_NAME,
|
||||
endpoint_address,
|
||||
@@ -141,87 +138,26 @@ async fn x_ping_h2_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
stream,
|
||||
http::Version::HTTP_2,
|
||||
&format!(
|
||||
"https://{}:{}",
|
||||
"https://{}:{}/ping",
|
||||
common::MAIN_DOMAIN_NAME,
|
||||
endpoint_address.port()
|
||||
),
|
||||
&[("x-ping", "1")],
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.0
|
||||
.status
|
||||
}
|
||||
|
||||
async fn x_ping_h3_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
async fn path_h3_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
let mut conn =
|
||||
common::Http3Session::connect(endpoint_address, common::MAIN_DOMAIN_NAME, None).await;
|
||||
conn.send_request(
|
||||
Request::get(format!(
|
||||
"https://{}:{}",
|
||||
"https://{}:{}/ping",
|
||||
common::MAIN_DOMAIN_NAME,
|
||||
endpoint_address.port()
|
||||
))
|
||||
.header("x-ping", "1")
|
||||
.body(hyper::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
.await;
|
||||
|
||||
conn.recv_response().await.status
|
||||
}
|
||||
|
||||
async fn navigate_h1_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
let stream =
|
||||
common::establish_tls_connection(common::MAIN_DOMAIN_NAME, endpoint_address, None).await;
|
||||
|
||||
common::do_get_request(
|
||||
stream,
|
||||
http::Version::HTTP_11,
|
||||
&format!(
|
||||
"https://{}:{}",
|
||||
common::MAIN_DOMAIN_NAME,
|
||||
endpoint_address.port()
|
||||
),
|
||||
&[("sec-fetch-mode", "navigate")],
|
||||
)
|
||||
.await
|
||||
.0
|
||||
.status
|
||||
}
|
||||
|
||||
async fn navigate_h2_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
let stream = common::establish_tls_connection(
|
||||
common::MAIN_DOMAIN_NAME,
|
||||
endpoint_address,
|
||||
Some(net_utils::HTTP2_ALPN.as_bytes()),
|
||||
)
|
||||
.await;
|
||||
|
||||
common::do_get_request(
|
||||
stream,
|
||||
http::Version::HTTP_2,
|
||||
&format!(
|
||||
"https://{}:{}",
|
||||
common::MAIN_DOMAIN_NAME,
|
||||
endpoint_address.port()
|
||||
),
|
||||
&[("sec-fetch-mode", "navigate")],
|
||||
)
|
||||
.await
|
||||
.0
|
||||
.status
|
||||
}
|
||||
|
||||
async fn navigate_h3_client(endpoint_address: &SocketAddr) -> http::StatusCode {
|
||||
let mut conn =
|
||||
common::Http3Session::connect(endpoint_address, common::MAIN_DOMAIN_NAME, None).await;
|
||||
conn.send_request(
|
||||
Request::get(format!(
|
||||
"https://{}:{}",
|
||||
common::MAIN_DOMAIN_NAME,
|
||||
endpoint_address.port()
|
||||
))
|
||||
.header("sec-fetch-mode", "navigate")
|
||||
.body(hyper::Body::empty())
|
||||
.unwrap(),
|
||||
)
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
use bytes::Bytes;
|
||||
use futures::stream;
|
||||
use http::{Request, Response};
|
||||
use log::info;
|
||||
use once_cell::sync::Lazy;
|
||||
use ring::digest::{digest, SHA256};
|
||||
use std::future::Future;
|
||||
use std::net::{Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use trusttunnel::net_utils;
|
||||
use trusttunnel::settings::{
|
||||
Http1Settings, Http2Settings, ListenProtocolSettings, QuicSettings, ReverseProxySettings,
|
||||
Settings, TlsHostInfo, TlsHostsSettings,
|
||||
@@ -13,6 +17,9 @@ use trusttunnel::settings::{
|
||||
#[allow(dead_code)]
|
||||
mod common;
|
||||
|
||||
// Use a larger body to catch partial responses without flooding logs.
|
||||
static RESPONSE_BODY: Lazy<Bytes> = Lazy::new(|| Bytes::from(vec![b'x'; 1024 * 1024]));
|
||||
|
||||
macro_rules! reverse_proxy_tests {
|
||||
($($name:ident: $client_fn:expr,)*) => {
|
||||
$(
|
||||
@@ -26,7 +33,7 @@ macro_rules! reverse_proxy_tests {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
let (response, body) = $client_fn(&endpoint_address).await;
|
||||
assert_eq!(response.status, http::StatusCode::OK);
|
||||
assert_eq!(body.as_ref(), b"how much watch?");
|
||||
assert_body_matches(&body);
|
||||
};
|
||||
|
||||
// Pin both tasks to avoid moving them
|
||||
@@ -53,11 +60,91 @@ macro_rules! reverse_proxy_tests {
|
||||
|
||||
reverse_proxy_tests! {
|
||||
sni_h1: sni_h1_client,
|
||||
sni_h3: sni_h3_client,
|
||||
path_h1: path_h1_client,
|
||||
path_h2: path_h2_client,
|
||||
}
|
||||
|
||||
// TODO: [TRUST-211] Enable H3 tests on Linux when QUIC issue will be fixed.
|
||||
#[cfg(target_os = "macos")]
|
||||
reverse_proxy_tests! {
|
||||
sni_h3: sni_h3_client,
|
||||
path_h3: path_h3_client,
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn path_h2_chunked() {
|
||||
common::set_up_logger();
|
||||
let endpoint_address = common::make_endpoint_address();
|
||||
let (proxy_address, proxy_task) = run_proxy_chunked();
|
||||
|
||||
let client_task = async {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
let (response, body) = path_h2_client(&endpoint_address).await;
|
||||
assert_eq!(response.status, http::StatusCode::OK);
|
||||
assert_body_matches(&body);
|
||||
};
|
||||
|
||||
tokio::pin!(client_task);
|
||||
tokio::pin!(proxy_task);
|
||||
|
||||
tokio::select! {
|
||||
_ = run_endpoint(&endpoint_address, &proxy_address) => unreachable!(),
|
||||
_ = tokio::time::sleep(Duration::from_secs(10)) => panic!("Timed out"),
|
||||
_ = &mut client_task => (),
|
||||
_ = &mut proxy_task => {
|
||||
tokio::select! {
|
||||
_ = client_task => (),
|
||||
_ = tokio::time::sleep(Duration::from_secs(5)) => {
|
||||
panic!("Client timed out after proxy completed")
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
#[tokio::test]
|
||||
async fn path_h3_chunked() {
|
||||
common::set_up_logger();
|
||||
let endpoint_address = common::make_endpoint_address();
|
||||
let (proxy_address, proxy_task) = run_proxy_chunked();
|
||||
|
||||
let client_task = async {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
let (response, body) = path_h3_client(&endpoint_address).await;
|
||||
assert_eq!(response.status, http::StatusCode::OK);
|
||||
assert_body_matches(&body);
|
||||
};
|
||||
|
||||
tokio::pin!(client_task);
|
||||
tokio::pin!(proxy_task);
|
||||
|
||||
tokio::select! {
|
||||
_ = run_endpoint(&endpoint_address, &proxy_address) => unreachable!(),
|
||||
_ = tokio::time::sleep(Duration::from_secs(10)) => panic!("Timed out"),
|
||||
_ = &mut client_task => (),
|
||||
_ = &mut proxy_task => {
|
||||
tokio::select! {
|
||||
_ = client_task => (),
|
||||
_ = tokio::time::sleep(Duration::from_secs(5)) => {
|
||||
panic!("Client timed out after proxy completed")
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn assert_body_matches(body: &Bytes) {
|
||||
assert_eq!(body.len(), RESPONSE_BODY.len(), "response length mismatch");
|
||||
let expected_hash = digest(&SHA256, RESPONSE_BODY.as_ref());
|
||||
let actual_hash = digest(&SHA256, body.as_ref());
|
||||
assert_eq!(
|
||||
actual_hash.as_ref(),
|
||||
expected_hash.as_ref(),
|
||||
"response hash mismatch"
|
||||
);
|
||||
}
|
||||
|
||||
async fn sni_h1_client(endpoint_address: &SocketAddr) -> (http::response::Parts, Bytes) {
|
||||
let stream = common::establish_tls_connection(
|
||||
&format!("hello.{}", common::MAIN_DOMAIN_NAME),
|
||||
@@ -79,6 +166,7 @@ async fn sni_h1_client(endpoint_address: &SocketAddr) -> (http::response::Parts,
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
async fn sni_h3_client(endpoint_address: &SocketAddr) -> (http::response::Parts, Bytes) {
|
||||
let mut conn = common::Http3Session::connect(
|
||||
endpoint_address,
|
||||
@@ -116,6 +204,28 @@ async fn path_h1_client(endpoint_address: &SocketAddr) -> (http::response::Parts
|
||||
.await
|
||||
}
|
||||
|
||||
async fn path_h2_client(endpoint_address: &SocketAddr) -> (http::response::Parts, Bytes) {
|
||||
let stream = common::establish_tls_connection(
|
||||
common::MAIN_DOMAIN_NAME,
|
||||
endpoint_address,
|
||||
Some(net_utils::HTTP2_ALPN.as_bytes()),
|
||||
)
|
||||
.await;
|
||||
|
||||
common::do_get_request(
|
||||
stream,
|
||||
http::Version::HTTP_2,
|
||||
&format!(
|
||||
"https://{}:{}/hello/haha",
|
||||
common::MAIN_DOMAIN_NAME,
|
||||
endpoint_address.port()
|
||||
),
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
async fn path_h3_client(endpoint_address: &SocketAddr) -> (http::response::Parts, Bytes) {
|
||||
let mut conn =
|
||||
common::Http3Session::connect(endpoint_address, common::MAIN_DOMAIN_NAME, None).await;
|
||||
@@ -193,11 +303,45 @@ fn run_proxy() -> (SocketAddr, impl Future<Output = ()>) {
|
||||
})
|
||||
}
|
||||
|
||||
fn run_proxy_chunked() -> (SocketAddr, impl Future<Output = ()>) {
|
||||
let server = std::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
|
||||
let _ = server.set_nonblocking(true);
|
||||
let server_addr = server.local_addr().unwrap();
|
||||
(server_addr, async move {
|
||||
let (socket, peer) = TcpListener::from_std(server)
|
||||
.unwrap()
|
||||
.accept()
|
||||
.await
|
||||
.unwrap();
|
||||
info!("New connection from {}", peer);
|
||||
hyper::server::conn::Http::new()
|
||||
.http1_only(true)
|
||||
.serve_connection(socket, hyper::service::service_fn(request_handler_chunked))
|
||||
.await
|
||||
.unwrap();
|
||||
})
|
||||
}
|
||||
|
||||
async fn request_handler(
|
||||
request: Request<hyper::Body>,
|
||||
) -> Result<Response<hyper::Body>, hyper::Error> {
|
||||
info!("Received request: {:?}", request);
|
||||
Ok(Response::builder()
|
||||
.body(hyper::Body::from("how much watch?"))
|
||||
.body(hyper::Body::from(RESPONSE_BODY.clone()))
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
async fn request_handler_chunked(
|
||||
request: Request<hyper::Body>,
|
||||
) -> Result<Response<hyper::Body>, hyper::Error> {
|
||||
info!("Received request: {:?}", request);
|
||||
let chunk_size = 16 * 1024;
|
||||
let chunks = RESPONSE_BODY
|
||||
.chunks(chunk_size)
|
||||
.map(|c| Ok::<Bytes, std::io::Error>(Bytes::copy_from_slice(c)))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok(Response::builder()
|
||||
.body(hyper::Body::wrap_stream(stream::iter(chunks)))
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
@@ -40,6 +40,18 @@ fn compose_main_table(settings: &Settings, credentials_path: &str, rules_path: &
|
||||
doc["udp_connections_timeout_secs"] =
|
||||
value(settings.get_udp_connections_timeout().as_secs() as i64);
|
||||
doc["speedtest_enable"] = value(*settings.get_speedtest_enable());
|
||||
if let Some(path) = settings.get_speedtest_path().as_ref() {
|
||||
doc["speedtest_path"] = value(path.clone());
|
||||
} else {
|
||||
doc.remove("speedtest_path");
|
||||
}
|
||||
doc["ping_enable"] = value(*settings.get_ping_enable());
|
||||
if let Some(path) = settings.get_ping_path().as_ref() {
|
||||
doc["ping_path"] = value(path.clone());
|
||||
} else {
|
||||
doc.remove("ping_path");
|
||||
}
|
||||
doc["auth_failure_status_code"] = value(*settings.get_auth_failure_status_code() as i64);
|
||||
|
||||
doc.to_string()
|
||||
}
|
||||
|
||||
@@ -67,6 +67,18 @@ udp_connections_timeout_secs = {}
|
||||
|
||||
{}
|
||||
speedtest_enable = {}
|
||||
|
||||
{}
|
||||
speedtest_path = "{}"
|
||||
|
||||
{}
|
||||
ping_enable = {}
|
||||
|
||||
{}
|
||||
ping_path = "{}"
|
||||
|
||||
{}
|
||||
auth_failure_status_code = {}
|
||||
"#,
|
||||
Settings::doc_listen_address().to_toml_comment(),
|
||||
crate::library_settings::DEFAULT_CREDENTIALS_PATH,
|
||||
@@ -91,6 +103,14 @@ speedtest_enable = {}
|
||||
Settings::default_udp_connections_timeout().as_secs(),
|
||||
Settings::doc_speedtest_enable().to_toml_comment(),
|
||||
Settings::default_speedtest_enable(),
|
||||
Settings::doc_speedtest_path().to_toml_comment(),
|
||||
Settings::default_speedtest_path().unwrap(),
|
||||
Settings::doc_ping_enable().to_toml_comment(),
|
||||
Settings::default_ping_enable(),
|
||||
Settings::doc_ping_path().to_toml_comment(),
|
||||
Settings::default_ping_path().unwrap(),
|
||||
Settings::doc_auth_failure_status_code().to_toml_comment(),
|
||||
Settings::default_auth_failure_status_code(),
|
||||
)
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user