Files
TrustTunnel/lib/tests/reverse_proxy.rs
Andrey Meshkov ddffd17835 Pull request 123: TRUST-241: Add logging and fix hanging connections
Squashed commit of the following:

commit 07e1a490f3e7a7f381de5f91c9645911aa4f3ada
Author: Andrey Meshkov <am@adguard.com>
Date:   Sat Dec 27 13:26:26 2025 +0300

    Fix tests

commit ed836634066c89934bac989e7785096f2a86b5c2
Author: Andrey Meshkov <am@adguard.com>
Date:   Fri Dec 26 21:03:36 2025 +0300

    Added cross-compile instructions

commit 497733e542bad09ae9606aa6fc0c244d755ddae5
Author: Andrey Meshkov <am@adguard.com>
Date:   Fri Dec 26 20:50:39 2025 +0300

    Added tcp keepalive

commit 8a4c9c2b56c35ade23ccbd1b6ec2413c4ad072ad
Author: Andrey Meshkov <am@adguard.com>
Date:   Fri Dec 26 20:37:19 2025 +0300

    Fix eof in tcp forwarder

commit 84af0fecbef3a2c96a63f4d0e2c5bbd6a97fa26d
Merge: 6adfe25 8d8addd
Author: Andrey Meshkov <am@adguard.com>
Date:   Fri Dec 26 18:30:03 2025 +0300

    Merge branch 'master' into add_logging

commit 6adfe2576824f18f9174c6d7e0c3a86b272d350f
Merge: f723e93 0daedac
Author: Andrey Meshkov <am@adguard.com>
Date:   Fri Dec 26 18:26:27 2025 +0300

    Merge branch 'fix/TRUST-231' into add_logging

commit f723e93bbd8f98afd686db3895621654946e49d0
Author: Andrey Meshkov <am@adguard.com>
Date:   Fri Dec 26 18:18:50 2025 +0300

    Added more logging

commit 0daedac118bcd6e4e85e8223ca120fc5be41f275
Author: Zhavoronkov Aleksei <a.zhavoronkov@adguard.com>
Date:   Thu Dec 25 13:49:49 2025 +0300

    Use tls_handshake_timeout for handshake
2025-12-29 12:41:55 +03:00

204 lines
6.3 KiB
Rust

use bytes::Bytes;
use http::{Request, Response};
use log::info;
use std::future::Future;
use std::net::{Ipv4Addr, SocketAddr};
use std::time::Duration;
use tokio::net::TcpListener;
use trusttunnel::settings::{
Http1Settings, Http2Settings, ListenProtocolSettings, QuicSettings, ReverseProxySettings,
Settings, TlsHostInfo, TlsHostsSettings,
};
// Shared integration-test helpers (logger, TLS/HTTP3 clients, endpoint runner).
// Presumably the module is shared with other test crates and this file does not
// use every helper, hence the dead-code allowance — confirm if it ever shrinks.
#[allow(dead_code)]
mod common;
/// Generates one `#[tokio::test]` per `name: client_fn` pair.
///
/// Each generated test runs three futures concurrently:
/// - the TrustTunnel endpoint, configured to reverse-proxy to a local backend;
/// - the single-connection HTTP/1 backend from `run_proxy`;
/// - the client, which must receive `200 OK` with the body `how much watch?`.
///
/// The test passes as soon as the client's assertions succeed. It panics if
/// that does not happen within 10 seconds.
macro_rules! reverse_proxy_tests {
    ($($name:ident: $client_fn:expr,)*) => {
        $(
            #[tokio::test]
            async fn $name() {
                common::set_up_logger();
                let endpoint_address = common::make_endpoint_address();
                let (proxy_address, proxy_server) = run_proxy();

                let client_task = async {
                    // Give the endpoint and the backend a moment to start listening.
                    tokio::time::sleep(Duration::from_secs(1)).await;
                    let (response, body) = $client_fn(&endpoint_address).await;
                    assert_eq!(response.status, http::StatusCode::OK);
                    assert_eq!(body.as_ref(), b"how much watch?");
                };

                // The backend serves one connection and returns once the endpoint
                // closes it. That may happen before the client has finished reading
                // the response. Ending the `select!` at that point would drop the
                // endpoint future and could cut the client off mid-response. So
                // after the backend finishes, park forever; only the client (or the
                // timeout) decides when the test is over.
                let proxy_task = async move {
                    proxy_server.await;
                    std::future::pending::<()>().await
                };

                tokio::select! {
                    _ = run_endpoint(&endpoint_address, &proxy_address) => unreachable!(),
                    _ = proxy_task => unreachable!(),
                    _ = client_task => (),
                    _ = tokio::time::sleep(Duration::from_secs(10)) => panic!("Timed out"),
                }
            }
        )*
    }
}
// Test matrix:
// - `sni_*` tests reach the backend through the `hello.<MAIN_DOMAIN_NAME>`
//   reverse-proxy TLS host.
// - `path_*` tests reach it through the `/hello` path mask on the main domain.
// - `*_h1` tests use HTTP/1.1 over TLS; `*_h3` tests use HTTP/3.
reverse_proxy_tests! {
sni_h1: sni_h1_client,
sni_h3: sni_h3_client,
path_h1: path_h1_client,
path_h3: path_h3_client,
}
/// HTTP/1.1 client that reaches the backend by SNI.
///
/// It opens TLS with the server name `hello.<MAIN_DOMAIN_NAME>` and issues a
/// GET for that host's root. It returns the response head and body.
async fn sni_h1_client(endpoint_address: &SocketAddr) -> (http::response::Parts, Bytes) {
    let server_name = format!("hello.{}", common::MAIN_DOMAIN_NAME);
    let tls_stream = common::establish_tls_connection(&server_name, endpoint_address, None).await;

    let url = format!(
        "https://hello.{}:{}",
        common::MAIN_DOMAIN_NAME,
        endpoint_address.port()
    );
    common::do_get_request(tls_stream, http::Version::HTTP_11, &url, &[]).await
}
/// HTTP/3 client that reaches the backend by SNI.
///
/// It opens a QUIC session with the server name `hello.<MAIN_DOMAIN_NAME>` and
/// sends a GET for that host's root. It returns the response head and body.
async fn sni_h3_client(endpoint_address: &SocketAddr) -> (http::response::Parts, Bytes) {
    let server_name = format!("hello.{}", common::MAIN_DOMAIN_NAME);
    let mut session =
        common::Http3Session::connect(endpoint_address, &server_name, None).await;

    let uri = format!(
        "https://hello.{}:{}",
        common::MAIN_DOMAIN_NAME,
        endpoint_address.port()
    );
    let request = Request::get(uri).body(hyper::Body::empty()).unwrap();
    session.exchange(request).await
}
/// HTTP/1.1 client that reaches the backend by path.
///
/// It uses the main domain as the TLS server name and requests `/hello/haha`,
/// which falls under the `/hello` path mask. It also sends an `Upgrade: 1`
/// header. NOTE(review): the purpose of that header is not visible in this
/// file. Presumably it exercises how the endpoint forwards/routes such
/// requests — confirm.
async fn path_h1_client(endpoint_address: &SocketAddr) -> (http::response::Parts, Bytes) {
    let url = format!(
        "https://{}:{}/hello/haha",
        common::MAIN_DOMAIN_NAME,
        endpoint_address.port()
    );
    let tls_stream =
        common::establish_tls_connection(common::MAIN_DOMAIN_NAME, endpoint_address, None).await;
    common::do_get_request(
        tls_stream,
        http::Version::HTTP_11,
        &url,
        &[(http::header::UPGRADE.as_str(), "1")],
    )
    .await
}
/// HTTP/3 client that reaches the backend by path.
///
/// It connects to the main domain over QUIC and sends a GET for `/hello/haha`,
/// which falls under the `/hello` path mask. It returns the response head and
/// body.
async fn path_h3_client(endpoint_address: &SocketAddr) -> (http::response::Parts, Bytes) {
    let mut session =
        common::Http3Session::connect(endpoint_address, common::MAIN_DOMAIN_NAME, None).await;

    let uri = format!(
        "https://{}:{}/hello/haha",
        common::MAIN_DOMAIN_NAME,
        endpoint_address.port()
    );
    let request = Request::get(uri).body(hyper::Body::empty()).unwrap();
    session.exchange(request).await
}
/// Runs a TrustTunnel endpoint on `endpoint_address` with HTTP/1, HTTP/2 and
/// QUIC listeners. Traffic is reverse-proxied to `proxy_address` in two cases:
/// - requests whose path matches the `/hello` mask;
/// - TLS connections for the `hello.<MAIN_DOMAIN_NAME>` reverse-proxy host.
///
/// Private-network upstreams are allowed because the backend listens on
/// localhost. The future resolves only when
/// `common::run_endpoint_with_settings` does.
async fn run_endpoint(endpoint_address: &SocketAddr, proxy_address: &SocketAddr) {
    // Enable every listener flavour so both the h1 and h3 clients can connect.
    let protocols = ListenProtocolSettings {
        http1: Some(Http1Settings::builder().build()),
        http2: Some(Http2Settings::builder().build()),
        quic: Some(QuicSettings::builder().build()),
    };
    let settings = Settings::builder()
        .listen_address(endpoint_address)
        .unwrap()
        .listen_protocols(protocols)
        .reverse_proxy(
            ReverseProxySettings::builder()
                .server_address(proxy_address)
                .unwrap()
                .path_mask("/hello".to_string())
                .build()
                .unwrap(),
        )
        .allow_private_network_connections(true)
        .build()
        .unwrap();

    // `cert_key_file` stays bound until the endpoint returns. Presumably it is
    // a temporary file that must outlive the endpoint — confirm in `common`.
    let cert_key_file = common::make_cert_key_file();
    let cert_key_path = cert_key_file.path.to_str().unwrap();

    // The main host and the reverse-proxy host share one certificate/key file.
    let tls_host = |hostname: String| TlsHostInfo {
        hostname,
        cert_chain_path: cert_key_path.to_string(),
        private_key_path: cert_key_path.to_string(),
        allowed_sni: vec![],
    };
    let hosts_settings = TlsHostsSettings::builder()
        .main_hosts(vec![tls_host(common::MAIN_DOMAIN_NAME.to_string())])
        .reverse_proxy_hosts(vec![tls_host(format!(
            "hello.{}",
            common::MAIN_DOMAIN_NAME
        ))])
        .build()
        .unwrap();

    common::run_endpoint_with_settings(settings, hosts_settings).await;
}
/// Binds a single-shot HTTP/1 backend on an ephemeral localhost port.
///
/// Returns the bound address together with a future. When polled, the future
/// accepts exactly one TCP connection and serves it with `request_handler`
/// until the connection ends. The listener is bound eagerly, before the future
/// is polled, so the returned address can be handed to the endpoint settings
/// right away.
///
/// # Panics
///
/// Panics if any of these fail: binding, switching the socket to non-blocking
/// mode, accepting, or serving the connection.
fn run_proxy() -> (SocketAddr, impl Future<Output = ()>) {
    let server = std::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).unwrap();
    // `tokio::net::TcpListener::from_std` requires a non-blocking socket. A
    // blocking one would make `accept` block the runtime thread, so a failure
    // here must not be ignored.
    server
        .set_nonblocking(true)
        .expect("failed to switch the proxy listener to non-blocking mode");
    let server_addr = server.local_addr().unwrap();
    (server_addr, async move {
        // `from_std` must run inside a Tokio runtime, hence it is deferred to here.
        let (socket, peer) = TcpListener::from_std(server)
            .unwrap()
            .accept()
            .await
            .unwrap();
        info!("New connection from {}", peer);
        hyper::server::conn::Http::new()
            .http1_only(true)
            .serve_connection(socket, hyper::service::service_fn(request_handler))
            .await
            .unwrap();
    })
}
/// Backend handler: logs the incoming request and answers every request with
/// a default response (`200 OK`, HTTP/1.1, no headers) whose body is
/// `how much watch?`. It never returns an error.
async fn request_handler(
    request: Request<hyper::Body>,
) -> Result<Response<hyper::Body>, hyper::Error> {
    info!("Received request: {:?}", request);
    let body = hyper::Body::from("how much watch?");
    Ok(Response::new(body))
}