From 50bbf3185a8eb3d6fca8074ef3da727add494f27 Mon Sep 17 00:00:00 2001 From: iHsin Date: Mon, 15 Jun 2026 21:28:42 +0800 Subject: [PATCH 1/2] test: expand unit/integration coverage; add client reconnect with tests Adds graceful-shutdown coverage across the cancellation chain, fills pure-function test gaps, exercises the SOCKS5 TCP/auth path end-to-end, and implements a configurable client auto-reconnect for the TUIC outbound with tests. Graceful shutdown: - Unit tests for the server-core `acceptor_loop` cancellation primitive plus `is_tuic_prefix`/`read_prefix` (wind-tuic server core). - Integration tests that the TUIC inbound listen loop and per-connection tasks drain on cancel; same for the SOCKS5 inbound accept loop and the tuic-client TCP/UDP forwarders. Pure-function coverage: - wind-core `is_private_ip`, `StackPrefer::from_str`, `pick_addr_by_preference`, `filter_addrs_by_preference`. - wind-socks `parse_udp_request_sync`, `unmap_v4_mapped`, `target_addr_to_socket`, `convert_addr`/`convert_to_socks_addr` round-trips. - wind-base `resolve_target` (IP-literal bypass + domain via resolver). SOCKS5 TCP end-to-end (dependency-free client): no-auth + password auth (right/wrong), IPv4 and domain CONNECT with echo round-trip, unsupported-command. Client reconnect (new feature): - TuicOutbound now holds the connection in an `ArcSwap`; a supervisor task runs one session until the connection drops, then reconnects with exponential backoff and swaps the fresh connection in. Shutdown is honored even mid- handshake. In-flight streams are not resurrected (callers retry). - Reconnect is configurable via `ReconnectConfig` (enabled + backoff bounds), wired from the client `Relay` config with serde defaults (backward compatible). - Tests: pure `next_backoff`, reconnect-after-restart, shutdown-while-reconnecting, reconnect-disabled, and config parsing/defaults. Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/tuic-client/src/config.rs | 44 ++ crates/tuic-client/src/wind_adapter.rs | 9 +- crates/tuic-client/tests/graceful_shutdown.rs | 59 +++ crates/tuic-tests/tests/integration_tests.rs | 3 + crates/wind-base/Cargo.toml | 3 + crates/wind-base/src/resolve.rs | 58 +++ crates/wind-core/src/resolve.rs | 73 +++ crates/wind-core/src/utils.rs | 79 +++ crates/wind-socks/Cargo.toml | 2 +- crates/wind-socks/src/lib.rs | 34 ++ crates/wind-socks/src/udp.rs | 122 +++++ crates/wind-socks/tests/graceful_shutdown.rs | 97 ++++ crates/wind-socks/tests/socks_tcp.rs | 227 +++++++++ crates/wind-test/src/tuic.rs | 305 +++++++++++- crates/wind-tuic/src/quinn/outbound.rs | 464 +++++++++++++----- crates/wind-tuic/src/quinn/tls.rs | 1 + crates/wind-tuic/src/server/mod.rs | 203 ++++++++ crates/wind/src/conf/runtime.rs | 1 + 18 files changed, 1644 insertions(+), 140 deletions(-) create mode 100644 crates/tuic-client/tests/graceful_shutdown.rs create mode 100644 crates/wind-socks/tests/graceful_shutdown.rs create mode 100644 crates/wind-socks/tests/socks_tcp.rs diff --git a/crates/tuic-client/src/config.rs b/crates/tuic-client/src/config.rs index 35064d2..6b02914 100644 --- a/crates/tuic-client/src/config.rs +++ b/crates/tuic-client/src/config.rs @@ -158,6 +158,20 @@ pub struct Relay { #[educe(Default = None)] pub proxy: Option, + + /// Automatically reconnect to the relay after the connection drops. + #[educe(Default = true)] + pub reconnect: bool, + + /// Delay before the first reconnect attempt; doubled after each failure. + #[educe(Default(expression = Duration::from_millis(500)))] + #[serde(with = "humantime_serde")] + pub reconnect_initial_backoff: Duration, + + /// Upper bound on the reconnect backoff delay. + #[educe(Default(expression = Duration::from_secs(30)))] + #[serde(with = "humantime_serde")] + pub reconnect_max_backoff: Duration, } #[derive(Debug, Deserialize, serde::Serialize, Educe, Clone, PartialEq, Eq)] @@ -557,8 +571,38 @@ mod tests { assert_eq!(config.relay.gc_interval, Duration::from_secs(3)); assert_eq!(config.relay.gc_lifetime, Duration::from_secs(15)); assert!(!config.relay.skip_cert_verify); + // Reconnect defaults: enabled, 500ms initial backoff capped at 30s. + assert!(config.relay.reconnect); + assert_eq!(config.relay.reconnect_initial_backoff, Duration::from_millis(500)); + assert_eq!(config.relay.reconnect_max_backoff, Duration::from_secs(30)); assert_eq!(config.local.max_packet_size, 1500); } + + #[test] + fn test_reconnect_can_be_disabled_and_tuned() { + // A config omitting reconnect keeps the defaults (backward compatible). + let default_cfg = r#"{ "relay": { "server": "example.com:8443", "uuid": "00000000-0000-0000-0000-000000000000", "password": "pw" } }"#; + let config = test_parse_config(default_cfg, ".json5").unwrap(); + assert!(config.relay.reconnect); + + // Explicit values are honoured, including disabling reconnect and + // humantime-formatted backoff durations. + let tuned = r#"{ + "relay": { + "server": "example.com:8443", + "uuid": "00000000-0000-0000-0000-000000000000", + "password": "pw", + "reconnect": false, + "reconnect_initial_backoff": "2s", + "reconnect_max_backoff": "1m" + } + }"#; + let config = test_parse_config(tuned, ".json5").unwrap(); + assert!(!config.relay.reconnect); + assert_eq!(config.relay.reconnect_initial_backoff, Duration::from_secs(2)); + assert_eq!(config.relay.reconnect_max_backoff, Duration::from_secs(60)); + } + #[test] fn test_tcp_udp_forward() { let json5_config = include_str!("../tests/config/tcp_udp_forward.json5"); diff --git a/crates/tuic-client/src/wind_adapter.rs b/crates/tuic-client/src/wind_adapter.rs index 5971c5f..5c65a55 100644 --- a/crates/tuic-client/src/wind_adapter.rs +++ b/crates/tuic-client/src/wind_adapter.rs @@ -7,7 +7,7 @@ use std::{net::SocketAddr, sync::Arc}; use once_cell::sync::OnceCell; use wind_core::{AbstractOutbound, AppContext, tcp::AbstractTcpStream, types::TargetAddr, udp::UdpStream}; -use wind_tuic::quinn::outbound::{TuicOutbound, TuicOutboundOpts}; +use wind_tuic::quinn::outbound::{ReconnectConfig, TuicOutbound, TuicOutboundOpts}; use crate::config::Relay; @@ -62,6 +62,12 @@ impl TuicOutboundAdapter { } }; + let reconnect = ReconnectConfig { + enabled: relay.reconnect, + initial_backoff: relay.reconnect_initial_backoff, + max_backoff: relay.reconnect_max_backoff, + }; + let opts = TuicOutboundOpts { peer_addr: server_addr, sni, @@ -76,6 +82,7 @@ impl TuicOutboundAdapter { .into_iter() .map(|v| String::from_utf8_lossy(&v).to_string()) .collect(), + reconnect, }; let outbound: TuicOutbound = TuicOutbound::new(ctx, opts).await?; diff --git a/crates/tuic-client/tests/graceful_shutdown.rs b/crates/tuic-client/tests/graceful_shutdown.rs new file mode 100644 index 0000000..a5ec31f --- /dev/null +++ b/crates/tuic-client/tests/graceful_shutdown.rs @@ -0,0 +1,59 @@ +//! Graceful-shutdown test for the tuic-client TCP/UDP forwarders. +//! +//! `forward::start` spawns each forwarder into `ctx.tasks`, driven by a child +//! of `ctx.token`. Cancelling the token must break every accept/recv loop so +//! the tracker drains — this is the forwarder half of the client's +//! `run_with_cancel` shutdown path. + +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +use tuic_client::{ + config::{TcpForward, UdpForward}, + forward, +}; +use wind_core::AppContext; + +/// Reserve a free loopback TCP port (the listener is dropped immediately so the +/// forwarder can bind it). +fn free_tcp_addr() -> SocketAddr { + let l = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let a = l.local_addr().unwrap(); + drop(l); + a +} + +/// Reserve a free loopback UDP port. +fn free_udp_addr() -> SocketAddr { + let s = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let a = s.local_addr().unwrap(); + drop(s); + a +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn forwarders_drain_on_cancel() { + let ctx = Arc::new(AppContext::default()); + + let tcp = vec![TcpForward { + listen: free_tcp_addr(), + // Discard port (9) — never actually dialed; the loops are idle. + remote: ("127.0.0.1".to_string(), 9), + }]; + let udp = vec![UdpForward { + listen: free_udp_addr(), + remote: ("127.0.0.1".to_string(), 9), + timeout: Duration::from_secs(60), + }]; + + forward::start(tcp, udp, &ctx).await; + + // Let both forwarder loops bind and reach their `select!`. + tokio::time::sleep(Duration::from_millis(200)).await; + + // Graceful shutdown: cancel the context token, then drain the tracker. + ctx.token.cancel(); + ctx.tasks.close(); + tokio::time::timeout(Duration::from_secs(5), ctx.tasks.wait()) + .await + .expect("forwarder tasks did not drain within 5s of cancellation"); +} diff --git a/crates/tuic-tests/tests/integration_tests.rs b/crates/tuic-tests/tests/integration_tests.rs index 4d55e0d..d5940a5 100644 --- a/crates/tuic-tests/tests/integration_tests.rs +++ b/crates/tuic-tests/tests/integration_tests.rs @@ -597,6 +597,9 @@ async fn test_ipv6_server_client_integration() -> eyre::Result<()> { gc_lifetime: Duration::from_secs(15), skip_cert_verify: true, proxy: None, + reconnect: true, + reconnect_initial_backoff: Duration::from_millis(500), + reconnect_max_backoff: Duration::from_secs(30), }, local: tuic_client::config::Local { server: "[::1]:1081".parse()?, diff --git a/crates/wind-base/Cargo.toml b/crates/wind-base/Cargo.toml index fe14baf..393321a 100644 --- a/crates/wind-base/Cargo.toml +++ b/crates/wind-base/Cargo.toml @@ -15,3 +15,6 @@ eyre = "0.6" tracing = "0.1" bytes = "1" async-trait = "0.1" + +[dev-dependencies] +tokio = { version = "1", default-features = false, features = ["macros", "rt"] } diff --git a/crates/wind-base/src/resolve.rs b/crates/wind-base/src/resolve.rs index 2fcef7f..1cdf9ea 100644 --- a/crates/wind-base/src/resolve.rs +++ b/crates/wind-base/src/resolve.rs @@ -10,3 +10,61 @@ pub async fn resolve_target(target: &TargetAddr, resolver: &dyn Resolver) -> eyr TargetAddr::Domain(domain, port) => Ok(SocketAddr::new(resolver.resolve(domain).await?, *port)), } } + +#[cfg(test)] +mod tests { + use std::{future::Future, net::IpAddr, pin::Pin}; + + use super::*; + + /// Resolver that returns a fixed IP for any host. + struct FixedResolver(IpAddr); + + impl Resolver for FixedResolver { + fn resolve<'a>(&'a self, _host: &'a str) -> Pin> + Send + 'a>> { + let ip = self.0; + Box::pin(async move { Ok(ip) }) + } + + fn resolve_all<'a>(&'a self, _host: &'a str) -> Pin>> + Send + 'a>> { + let ip = self.0; + Box::pin(async move { Ok(vec![ip]) }) + } + } + + /// Resolver that panics if used — proves IP-literal targets never hit DNS. + struct PanicResolver; + + impl Resolver for PanicResolver { + fn resolve<'a>(&'a self, _host: &'a str) -> Pin> + Send + 'a>> { + Box::pin(async { panic!("resolver must not be called for IP-literal targets") }) + } + + fn resolve_all<'a>(&'a self, _host: &'a str) -> Pin>> + Send + 'a>> { + Box::pin(async { panic!("resolver must not be called for IP-literal targets") }) + } + } + + #[tokio::test] + async fn ip_literal_targets_bypass_the_resolver() { + let v4 = resolve_target(&TargetAddr::IPv4("192.168.1.1".parse().unwrap(), 8080), &PanicResolver) + .await + .unwrap(); + assert_eq!(v4.to_string(), "192.168.1.1:8080"); + + let v6 = resolve_target(&TargetAddr::IPv6("::1".parse().unwrap(), 443), &PanicResolver) + .await + .unwrap(); + assert!(v6.ip().is_ipv6()); + assert_eq!(v6.port(), 443); + } + + #[tokio::test] + async fn domain_targets_use_the_resolver_and_keep_the_port() { + let resolver = FixedResolver("203.0.113.7".parse().unwrap()); + let s = resolve_target(&TargetAddr::Domain("example.com".into(), 443), &resolver) + .await + .unwrap(); + assert_eq!(s.to_string(), "203.0.113.7:443"); + } +} diff --git a/crates/wind-core/src/resolve.rs b/crates/wind-core/src/resolve.rs index 0137f77..629b7bc 100644 --- a/crates/wind-core/src/resolve.rs +++ b/crates/wind-core/src/resolve.rs @@ -93,3 +93,76 @@ pub fn filter_addrs_by_preference(addrs: Vec, prefer: StackPrefer) -> Ve } } } + +#[cfg(test)] +mod tests { + use super::*; + + fn ips(list: &[&str]) -> Vec { + list.iter().map(|s| s.parse().unwrap()).collect() + } + + #[test] + fn pick_only_modes_require_matching_family() { + let mixed = ips(&["192.168.1.1", "2001:db8::1"]); + assert!(pick_addr_by_preference(mixed.clone(), StackPrefer::V4only).unwrap().is_ipv4()); + assert!(pick_addr_by_preference(mixed, StackPrefer::V6only).unwrap().is_ipv6()); + + assert!(pick_addr_by_preference(ips(&["2001:db8::1"]), StackPrefer::V4only).is_none()); + assert!(pick_addr_by_preference(ips(&["192.168.1.1"]), StackPrefer::V6only).is_none()); + } + + #[test] + fn pick_first_modes_fall_back_to_other_family() { + assert!( + pick_addr_by_preference(ips(&["2001:db8::1", "192.168.1.1"]), StackPrefer::V4first) + .unwrap() + .is_ipv4() + ); + // V4first with no IPv4 falls back to IPv6. + assert!( + pick_addr_by_preference(ips(&["2001:db8::1"]), StackPrefer::V4first) + .unwrap() + .is_ipv6() + ); + + assert!( + pick_addr_by_preference(ips(&["192.168.1.1", "2001:db8::1"]), StackPrefer::V6first) + .unwrap() + .is_ipv6() + ); + // V6first with no IPv6 falls back to IPv4. + assert!( + pick_addr_by_preference(ips(&["192.168.1.1"]), StackPrefer::V6first) + .unwrap() + .is_ipv4() + ); + } + + #[test] + fn pick_empty_list_is_none() { + assert!(pick_addr_by_preference(vec![], StackPrefer::V4first).is_none()); + } + + #[test] + fn filter_only_modes_keep_a_single_family() { + let addrs = ips(&["192.168.1.1", "2001:db8::1", "10.0.0.1"]); + let v4 = filter_addrs_by_preference(addrs.clone(), StackPrefer::V4only); + assert_eq!(v4, ips(&["192.168.1.1", "10.0.0.1"])); + let v6 = filter_addrs_by_preference(addrs, StackPrefer::V6only); + assert_eq!(v6, ips(&["2001:db8::1"])); + } + + #[test] + fn filter_first_modes_group_preferred_family_first_preserving_order() { + let addrs = ips(&["2001:db8::1", "192.168.1.1", "::1", "10.0.0.1"]); + assert_eq!( + filter_addrs_by_preference(addrs.clone(), StackPrefer::V4first), + ips(&["192.168.1.1", "10.0.0.1", "2001:db8::1", "::1"]), + ); + assert_eq!( + filter_addrs_by_preference(addrs, StackPrefer::V6first), + ips(&["2001:db8::1", "::1", "192.168.1.1", "10.0.0.1"]), + ); + } +} diff --git a/crates/wind-core/src/utils.rs b/crates/wind-core/src/utils.rs index a6a9316..cbd6415 100644 --- a/crates/wind-core/src/utils.rs +++ b/crates/wind-core/src/utils.rs @@ -71,3 +71,82 @@ pub fn is_private_ip(ip: &IpAddr) -> bool { } } } + +#[cfg(test)] +mod tests { + use std::net::IpAddr; + + use super::*; + + fn ip(s: &str) -> IpAddr { + s.parse().unwrap() + } + + #[test] + fn private_ipv4_ranges_are_private() { + for s in [ + "10.0.0.1", + "10.255.255.255", + "172.16.0.0", + "172.31.255.255", + "192.168.0.1", + "192.168.255.255", + "169.254.0.1", + ] { + assert!(is_private_ip(&ip(s)), "{s} should be private"); + } + } + + #[test] + fn public_ipv4_and_boundary_ranges_are_public() { + for s in [ + "8.8.8.8", + "1.1.1.1", + "11.0.0.1", + "172.15.255.255", // just below the 172.16/12 block + "172.32.0.0", // just above it + "192.167.255.255", + "169.253.0.1", + ] { + assert!(!is_private_ip(&ip(s)), "{s} should be public"); + } + } + + #[test] + fn private_ipv6_ranges_are_private() { + // fc00::/7 (fc.. and fd..) and fe80::/10 (fe80.. through febf..). + for s in ["fc00::1", "fd00::1", "fe80::1", "febf::1"] { + assert!(is_private_ip(&ip(s)), "{s} should be private"); + } + } + + #[test] + fn public_ipv6_ranges_are_public() { + // 2001:db8 doc range, loopback, and fec0 (outside fe80::/10). + for s in ["2001:db8::1", "::1", "fec0::1"] { + assert!(!is_private_ip(&ip(s)), "{s} should be public"); + } + } + + #[test] + fn stack_prefer_parses_all_aliases_case_insensitively() { + for s in ["v4", "v4only", "only_v4", "V4ONLY"] { + assert_eq!(s.parse::(), Ok(StackPrefer::V4only), "{s}"); + } + for s in ["v6", "v6only", "only_v6"] { + assert_eq!(s.parse::(), Ok(StackPrefer::V6only), "{s}"); + } + for s in ["v4v6", "v4first", "prefer_v4", "auto"] { + assert_eq!(s.parse::(), Ok(StackPrefer::V4first), "{s}"); + } + for s in ["v6v4", "v6first", "prefer_v6"] { + assert_eq!(s.parse::(), Ok(StackPrefer::V6first), "{s}"); + } + } + + #[test] + fn stack_prefer_rejects_unknown() { + assert!("nonsense".parse::().is_err()); + assert!("".parse::().is_err()); + } +} diff --git a/crates/wind-socks/Cargo.toml b/crates/wind-socks/Cargo.toml index 59b65cf..7d1f6ba 100644 --- a/crates/wind-socks/Cargo.toml +++ b/crates/wind-socks/Cargo.toml @@ -25,6 +25,6 @@ tracing = "0.1" async-trait = "0.1" [dev-dependencies] -tokio = { version = "1", default-features = false, features = ["macros", "rt", "rt-multi-thread", "net", "time", "sync"] } +tokio = { version = "1", default-features = false, features = ["macros", "rt", "rt-multi-thread", "net", "time", "sync", "io-util"] } eyre = "0.6" bytes = "1" diff --git a/crates/wind-socks/src/lib.rs b/crates/wind-socks/src/lib.rs index 246f777..9d9fe84 100644 --- a/crates/wind-socks/src/lib.rs +++ b/crates/wind-socks/src/lib.rs @@ -70,3 +70,37 @@ pub fn convert_to_socks_addr(addr: &TargetAddr) -> SocksTargetAddr { TargetAddr::IPv6(ipv6, port) => SocksTargetAddr::Ip(SocketAddr::V6(std::net::SocketAddrV6::new(*ipv6, *port, 0, 0))), } } + +#[cfg(test)] +mod tests { + use std::net::{Ipv4Addr, Ipv6Addr}; + + use super::*; + + #[test] + fn convert_addr_maps_each_family() { + assert_eq!( + convert_addr(&SocksTargetAddr::Domain("example.com".into(), 443)), + TargetAddr::Domain("example.com".into(), 443) + ); + assert_eq!( + convert_addr(&SocksTargetAddr::Ip("192.168.1.1:80".parse().unwrap())), + TargetAddr::IPv4(Ipv4Addr::new(192, 168, 1, 1), 80) + ); + assert_eq!( + convert_addr(&SocksTargetAddr::Ip("[2001:db8::1]:443".parse().unwrap())), + TargetAddr::IPv6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1), 443) + ); + } + + #[test] + fn convert_addr_and_back_is_identity() { + for t in [ + TargetAddr::Domain("example.com".into(), 443), + TargetAddr::IPv4(Ipv4Addr::new(10, 0, 0, 1), 8080), + TargetAddr::IPv6(Ipv6Addr::LOCALHOST, 53), + ] { + assert_eq!(convert_addr(&convert_to_socks_addr(&t)), t, "roundtrip for {t:?}"); + } + } +} diff --git a/crates/wind-socks/src/udp.rs b/crates/wind-socks/src/udp.rs index 471136d..5dab9d9 100644 --- a/crates/wind-socks/src/udp.rs +++ b/crates/wind-socks/src/udp.rs @@ -250,3 +250,125 @@ fn unmap_v4_mapped(ip: std::net::IpAddr) -> std::net::IpAddr { v4 => v4, } } + +#[cfg(test)] +mod tests { + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + use super::*; + + #[test] + fn parse_ipv4_request() { + let mut data = vec![0x00, 0x00, 0x00, 0x01]; + data.extend_from_slice(&[192, 168, 1, 1]); + data.extend_from_slice(&[0x1f, 0x90]); // port 8080 + data.extend_from_slice(b"payload"); + + let (frag, addr, payload) = parse_udp_request_sync(&data).unwrap(); + assert_eq!(frag, 0); + assert_eq!(payload, b"payload"); + match addr { + SocksTargetAddr::Ip(SocketAddr::V4(v4)) => { + assert_eq!(*v4.ip(), Ipv4Addr::new(192, 168, 1, 1)); + assert_eq!(v4.port(), 8080); + } + _ => panic!("expected IPv4 target"), + } + } + + #[test] + fn parse_domain_request() { + let mut data = vec![0x00, 0x00, 0x00, 0x03, 11]; + data.extend_from_slice(b"example.com"); + data.extend_from_slice(&[0x01, 0xbb]); // port 443 + data.extend_from_slice(b"hi"); + + let (_, addr, payload) = parse_udp_request_sync(&data).unwrap(); + assert_eq!(payload, b"hi"); + match addr { + SocksTargetAddr::Domain(d, p) => { + assert_eq!(d, "example.com"); + assert_eq!(p, 443); + } + _ => panic!("expected domain target"), + } + } + + #[test] + fn parse_ipv6_request() { + let ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1); + let mut data = vec![0x00, 0x00, 0x00, 0x04]; + data.extend_from_slice(&ip.octets()); + data.extend_from_slice(&[0x01, 0xbb]); // port 443 + + let (_, addr, payload) = parse_udp_request_sync(&data).unwrap(); + assert!(payload.is_empty()); + match addr { + SocksTargetAddr::Ip(SocketAddr::V6(v6)) => { + assert_eq!(*v6.ip(), ip); + assert_eq!(v6.port(), 443); + } + _ => panic!("expected IPv6 target"), + } + } + + #[test] + fn parse_rejects_malformed_headers() { + // Too short for the 4-byte header. + assert!(parse_udp_request_sync(&[0x00, 0x00]).is_err()); + // Non-zero reserved bytes. + assert!(parse_udp_request_sync(&[0x01, 0x00, 0x00, 0x01]).is_err()); + // Unsupported address type. + assert!(parse_udp_request_sync(&[0x00, 0x00, 0x00, 0x05]).is_err()); + // IPv4 atyp but truncated address. + assert!(parse_udp_request_sync(&[0x00, 0x00, 0x00, 0x01, 1, 2]).is_err()); + // Domain atyp, length 5, but only one domain byte present. + assert!(parse_udp_request_sync(&[0x00, 0x00, 0x00, 0x03, 5, b'a']).is_err()); + } + + #[test] + fn unmap_v4_mapped_unwraps_only_mapped_addresses() { + let v4 = Ipv4Addr::new(192, 168, 1, 1); + assert_eq!(unmap_v4_mapped(IpAddr::V6(v4.to_ipv6_mapped())), IpAddr::V4(v4)); + + let pure_v6: IpAddr = "2001:db8::1".parse().unwrap(); + assert_eq!(unmap_v4_mapped(pure_v6), pure_v6); + + let pure_v4 = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)); + assert_eq!(unmap_v4_mapped(pure_v4), pure_v4); + } + + #[test] + fn target_addr_to_socket_maps_families_and_domain_fallback() { + assert_eq!( + target_addr_to_socket(&TargetAddr::IPv4(Ipv4Addr::new(192, 168, 1, 1), 443)).to_string(), + "192.168.1.1:443" + ); + + let v6 = target_addr_to_socket(&TargetAddr::IPv6("::1".parse().unwrap(), 8080)); + assert!(v6.ip().is_ipv6()); + assert_eq!(v6.port(), 8080); + + // RFC 1928 has no "host" codepoint, so domains report 0.0.0.0:port. + assert_eq!( + target_addr_to_socket(&TargetAddr::Domain("example.com".into(), 53)).to_string(), + "0.0.0.0:53" + ); + } + + #[test] + fn convert_target_addr_maps_each_family() { + assert_eq!( + convert_target_addr(&SocksTargetAddr::Ip("192.168.1.1:80".parse().unwrap())), + TargetAddr::IPv4(Ipv4Addr::new(192, 168, 1, 1), 80) + ); + assert_eq!( + convert_target_addr(&SocksTargetAddr::Ip("[2001:db8::1]:443".parse().unwrap())), + TargetAddr::IPv6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1), 443) + ); + assert_eq!( + convert_target_addr(&SocksTargetAddr::Domain("example.com".into(), 443)), + TargetAddr::Domain("example.com".into(), 443) + ); + } +} diff --git a/crates/wind-socks/tests/graceful_shutdown.rs b/crates/wind-socks/tests/graceful_shutdown.rs new file mode 100644 index 0000000..14966be --- /dev/null +++ b/crates/wind-socks/tests/graceful_shutdown.rs @@ -0,0 +1,97 @@ +//! Graceful-shutdown tests for the SOCKS5 inbound accept loop. +//! +//! `SocksInbound::listen` selects on its `CancellationToken`; on cancellation +//! it breaks out of the accept loop, closes its per-connection `TaskTracker`, +//! and waits for in-flight sessions before returning. Each session also holds a +//! child token so a blocked handshake is aborted promptly. These tests confirm +//! the loop returns within a bounded time — both idle and with a live +//! connection — rather than hanging on shutdown. + +use std::{net::SocketAddr, time::Duration}; + +use tokio_util::sync::CancellationToken; +use wind_core::{AbstractInbound, InboundCallback, tcp::AbstractTcpStream, types::TargetAddr, udp::UdpStream}; +use wind_socks::inbound::{AuthMode, SocksInbound, SocksInboundOpt}; + +/// A callback whose handlers never complete on their own. Forces shutdown to be +/// driven by the cancellation chain (the per-connection child token), not by a +/// session finishing naturally. +#[derive(Clone)] +struct ParkCallback; + +impl InboundCallback for ParkCallback { + async fn handle_tcpstream(&self, _target: TargetAddr, _stream: impl AbstractTcpStream + 'static) -> eyre::Result<()> { + std::future::pending::<()>().await; + Ok(()) + } + + async fn handle_udpstream(&self, _udp_stream: UdpStream) -> eyre::Result<()> { + std::future::pending::<()>().await; + Ok(()) + } +} + +/// Bind a SOCKS inbound on a free loopback port and spawn its `listen` loop, +/// returning the bound address and the loop's join handle. +async fn spawn_inbound(cancel: CancellationToken) -> (SocketAddr, tokio::task::JoinHandle>) { + // Reserve a free port without holding the listener. + let probe = std::net::TcpListener::bind("127.0.0.1:0").expect("reserve port"); + let addr = probe.local_addr().unwrap(); + drop(probe); + + let opts = SocksInboundOpt { + listen_addr: addr, + public_addr: None, + auth: AuthMode::NoAuth, + skip_auth: false, + allow_udp: false, + }; + let inbound = SocksInbound::new(opts, cancel); + let handle = tokio::spawn(async move { + let cb = ParkCallback; + inbound.listen(&cb).await + }); + + // Let the loop bind and reach `accept()`. + tokio::time::sleep(Duration::from_millis(200)).await; + (addr, handle) +} + +/// An idle inbound must break its accept loop and return `Ok(())` promptly +/// after the token is cancelled. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn idle_listen_loop_exits_on_cancel() { + let cancel = CancellationToken::new(); + let (_addr, handle) = spawn_inbound(cancel.clone()).await; + + cancel.cancel(); + + let res = tokio::time::timeout(Duration::from_secs(5), handle) + .await + .expect("listen loop did not exit within 5s of cancellation") + .expect("listen task panicked"); + assert!(res.is_ok(), "listen returned an error on shutdown: {:?}", res.err()); +} + +/// With a live connection parked mid-session, cancellation must abort the +/// in-flight session (via its child token), let the per-connection +/// `TaskTracker` drain, and return — not block forever in `conn_tasks.wait()`. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn active_session_is_drained_on_cancel() { + let cancel = CancellationToken::new(); + let (addr, handle) = spawn_inbound(cancel.clone()).await; + + // Open a connection so the inbound spawns a tracked session task. We don't + // drive the SOCKS handshake — the handler is parked reading from the stream, + // which is exactly the in-flight state shutdown must abort. + let _client = tokio::net::TcpStream::connect(addr).await.expect("connect to inbound"); + tokio::time::sleep(Duration::from_millis(200)).await; + + cancel.cancel(); + + let res = tokio::time::timeout(Duration::from_secs(5), handle) + .await + .expect("listen loop did not drain in-flight session within 5s of cancellation") + .expect("listen task panicked"); + assert!(res.is_ok(), "listen returned an error on shutdown: {:?}", res.err()); +} diff --git a/crates/wind-socks/tests/socks_tcp.rs b/crates/wind-socks/tests/socks_tcp.rs new file mode 100644 index 0000000..767de73 --- /dev/null +++ b/crates/wind-socks/tests/socks_tcp.rs @@ -0,0 +1,227 @@ +//! End-to-end TCP tests for the SOCKS5 inbound, exercising the real handshake +//! over loopback: method negotiation (no-auth / RFC 1929 username-password), +//! CONNECT to IPv4 and domain targets, and the unsupported-command path. +//! +//! A minimal hand-rolled SOCKS5 client is used so the test depends only on the +//! wire protocol, not on any client library version. The inbound is wired to a +//! TCP-relay callback that dials the real target and copies bytes both ways, so +//! a successful CONNECT yields a genuine echo roundtrip. + +use std::{ + net::{Ipv4Addr, SocketAddr}, + time::Duration, +}; + +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream}, +}; +use tokio_util::sync::CancellationToken; +use wind_core::{AbstractInbound, InboundCallback, tcp::AbstractTcpStream, types::TargetAddr, udp::UdpStream}; +use wind_socks::inbound::{AuthMode, SocksInbound, SocksInboundOpt}; + +/// Inbound callback that relays an accepted SOCKS5 TCP stream to its real +/// target and copies bytes bidirectionally. +#[derive(Clone)] +struct TcpRelayCallback; + +impl InboundCallback for TcpRelayCallback { + async fn handle_tcpstream(&self, target: TargetAddr, mut stream: impl AbstractTcpStream + 'static) -> eyre::Result<()> { + let mut upstream = TcpStream::connect(target.to_string()).await?; + tokio::io::copy_bidirectional(&mut stream, &mut upstream).await?; + Ok(()) + } + + async fn handle_udpstream(&self, _udp_stream: UdpStream) -> eyre::Result<()> { + Ok(()) + } +} + +/// Spawn a TCP echo server on loopback; returns its address. +async fn spawn_echo() -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + while let Ok((mut sock, _)) = listener.accept().await { + tokio::spawn(async move { + let mut buf = [0u8; 4096]; + loop { + match sock.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(n) => { + if sock.write_all(&buf[..n]).await.is_err() { + break; + } + } + } + } + }); + } + }); + addr +} + +/// Spawn a SOCKS5 inbound with the given auth mode; returns its address and a +/// cancel token to shut it down. +async fn spawn_socks(auth: AuthMode) -> (SocketAddr, CancellationToken) { + let probe = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = probe.local_addr().unwrap(); + drop(probe); + + let opts = SocksInboundOpt { + listen_addr: addr, + public_addr: None, + auth, + skip_auth: false, + allow_udp: false, + }; + let cancel = CancellationToken::new(); + let inbound = SocksInbound::new(opts, cancel.clone()); + tokio::spawn(async move { + let cb = TcpRelayCallback; + let _ = inbound.listen(&cb).await; + }); + + tokio::time::sleep(Duration::from_millis(200)).await; + (addr, cancel) +} + +/// Negotiate the no-auth method (0x00). Panics if the server doesn't select it. +async fn negotiate_no_auth(s: &mut TcpStream) { + s.write_all(&[0x05, 0x01, 0x00]).await.unwrap(); + let mut resp = [0u8; 2]; + s.read_exact(&mut resp).await.unwrap(); + assert_eq!(resp, [0x05, 0x00], "server must select no-auth"); +} + +/// Run the RFC 1929 username/password sub-negotiation. Returns the status byte +/// (0x00 = success), or `Err` if the server closed the stream on rejection. +async fn negotiate_password(s: &mut TcpStream, user: &str, pass: &str) -> std::io::Result { + s.write_all(&[0x05, 0x01, 0x02]).await?; + let mut method = [0u8; 2]; + s.read_exact(&mut method).await?; + if method[1] != 0x02 { + // Server refused the username/password method outright. + return Ok(method[1]); + } + let mut req = vec![0x01, user.len() as u8]; + req.extend_from_slice(user.as_bytes()); + req.push(pass.len() as u8); + req.extend_from_slice(pass.as_bytes()); + s.write_all(&req).await?; + + let mut status = [0u8; 2]; + s.read_exact(&mut status).await?; + Ok(status[1]) +} + +/// Send a CONNECT request and return the reply code (0x00 = success). The +/// inbound always replies with an IPv4 BND.ADDR, so the reply is 10 bytes. +async fn connect_request(s: &mut TcpStream, atyp_body: Vec) -> u8 { + let mut req = vec![0x05, 0x01, 0x00]; + req.extend_from_slice(&atyp_body); + s.write_all(&req).await.unwrap(); + let mut reply = [0u8; 10]; + s.read_exact(&mut reply).await.unwrap(); + reply[1] +} + +fn ipv4_body(ip: Ipv4Addr, port: u16) -> Vec { + let mut b = vec![0x01]; + b.extend_from_slice(&ip.octets()); + b.extend_from_slice(&port.to_be_bytes()); + b +} + +fn domain_body(host: &str, port: u16) -> Vec { + let mut b = vec![0x03, host.len() as u8]; + b.extend_from_slice(host.as_bytes()); + b.extend_from_slice(&port.to_be_bytes()); + b +} + +/// Assert a CONNECTed stream echoes a payload back unchanged. +async fn assert_echo_roundtrip(s: &mut TcpStream) { + let msg = b"hello socks5"; + s.write_all(msg).await.unwrap(); + let mut buf = [0u8; 12]; + let read = tokio::time::timeout(Duration::from_secs(5), s.read_exact(&mut buf)).await; + assert!(read.is_ok(), "echo read timed out"); + read.unwrap().unwrap(); + assert_eq!(&buf, msg, "echoed payload must match"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn no_auth_tcp_connect_ipv4_echoes() { + let echo = spawn_echo().await; + let (proxy, _cancel) = spawn_socks(AuthMode::NoAuth).await; + + let mut s = TcpStream::connect(proxy).await.unwrap(); + negotiate_no_auth(&mut s).await; + let rep = connect_request(&mut s, ipv4_body(Ipv4Addr::LOCALHOST, echo.port())).await; + assert_eq!(rep, 0x00, "CONNECT must succeed"); + assert_echo_roundtrip(&mut s).await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn no_auth_tcp_connect_domain_echoes() { + let echo = spawn_echo().await; + let (proxy, _cancel) = spawn_socks(AuthMode::NoAuth).await; + + let mut s = TcpStream::connect(proxy).await.unwrap(); + negotiate_no_auth(&mut s).await; + // Domain target: the relay callback resolves "localhost" itself. + let rep = connect_request(&mut s, domain_body("localhost", echo.port())).await; + assert_eq!(rep, 0x00, "CONNECT to domain target must succeed"); + assert_echo_roundtrip(&mut s).await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn password_auth_accepts_correct_credentials() { + let echo = spawn_echo().await; + let (proxy, _cancel) = spawn_socks(AuthMode::Password { + username: "alice".into(), + password: "s3cret".into(), + }) + .await; + + let mut s = TcpStream::connect(proxy).await.unwrap(); + let status = negotiate_password(&mut s, "alice", "s3cret").await.unwrap(); + assert_eq!(status, 0x00, "correct credentials must authenticate"); + + let rep = connect_request(&mut s, ipv4_body(Ipv4Addr::LOCALHOST, echo.port())).await; + assert_eq!(rep, 0x00, "CONNECT must succeed after auth"); + assert_echo_roundtrip(&mut s).await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn password_auth_rejects_wrong_credentials() { + let (proxy, _cancel) = spawn_socks(AuthMode::Password { + username: "alice".into(), + password: "s3cret".into(), + }) + .await; + + let mut s = TcpStream::connect(proxy).await.unwrap(); + let result = negotiate_password(&mut s, "alice", "wrong").await; + // Either the server reports a non-zero status, or it closes the stream. + let rejected = matches!(result, Ok(st) if st != 0x00) || result.is_err(); + assert!(rejected, "wrong credentials must be rejected, got {result:?}"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn unsupported_command_is_rejected() { + let (proxy, _cancel) = spawn_socks(AuthMode::NoAuth).await; + + let mut s = TcpStream::connect(proxy).await.unwrap(); + negotiate_no_auth(&mut s).await; + + // CMD = 0x02 (BIND) is not supported by this inbound. + let mut req = vec![0x05, 0x02, 0x00]; + req.extend_from_slice(&ipv4_body(Ipv4Addr::LOCALHOST, 9)); + s.write_all(&req).await.unwrap(); + + let mut reply = [0u8; 10]; + s.read_exact(&mut reply).await.unwrap(); + assert_ne!(reply[1], 0x00, "BIND must not be reported as success"); +} diff --git a/crates/wind-test/src/tuic.rs b/crates/wind-test/src/tuic.rs index baab21c..707c740 100644 --- a/crates/wind-test/src/tuic.rs +++ b/crates/wind-test/src/tuic.rs @@ -140,7 +140,7 @@ mod tests { use wind_core::{AbstractInbound, AbstractOutbound}; use wind_tuic::quinn::{ inbound::{TuicInbound, TuicInboundOpts}, - outbound::{TuicOutbound, TuicOutboundOpts}, + outbound::{ReconnectConfig, TuicOutbound, TuicOutboundOpts}, }; use super::*; @@ -160,6 +160,21 @@ mod tests { } async fn setup_tuic_server() -> eyre::Result { + let (ctx, server_addr, uuid, _listen) = spawn_tuic_server().await?; + // The listen task is left detached: `TuicTestSetup::drop` cancels the + // context token, which makes the accept loop break on its own. + Ok(TuicTestSetup { server_addr, uuid, ctx }) + } + + /// Lower-level server bring-up that also hands back the `listen` + /// accept-loop join handle (and the context, so the caller can cancel it). + /// + /// Used by the graceful-shutdown tests, which need to *await* the accept + /// loop to prove it exits on cancellation — something + /// [`setup_tuic_server`] can't expose because [`TuicTestSetup`] owns a + /// `Drop` guard and can't surrender a field by move. + async fn spawn_tuic_server() -> eyre::Result<(Arc, SocketAddr, Uuid, tokio::task::JoinHandle>)> + { let (cert, key) = generate_tuic_test_cert(); let uuid = Uuid::new_v4(); let mut users = HashMap::new(); @@ -185,17 +200,47 @@ mod tests { let server = TuicInbound::new(ctx.clone(), opts); let callback = Arc::new(DirectCallback); + let listen = tokio::spawn(async move { server.listen(callback.as_ref()).await }.in_current_span()); + + // Allow the server time to bind and begin accepting + tokio::time::sleep(Duration::from_millis(300)).await; + + Ok((ctx, server_addr, uuid, listen)) + } + + /// Connect a TUIC client to `addr`/`uuid` and start its heartbeat poll. + /// Used by the graceful-shutdown tests, which need a client whose lifetime + /// is independent of the [`TuicTestSetup`] `Drop` guard. + async fn connect_client(addr: SocketAddr, uuid: Uuid) -> eyre::Result> { + connect_client_with(addr, uuid, ReconnectConfig::default()).await + } + + /// As [`connect_client`], but with an explicit reconnect policy — lets + /// tests disable reconnect or tune its backoff. + async fn connect_client_with(addr: SocketAddr, uuid: Uuid, reconnect: ReconnectConfig) -> eyre::Result> { + let ctx = Arc::new(AppContext::default()); + let opts = TuicOutboundOpts { + peer_addr: addr, + sni: "localhost".to_string(), + auth: (uuid, Arc::from(TEST_PASSWORD)), + zero_rtt_handshake: false, + heartbeat: Duration::from_secs(5), + gc_interval: Duration::from_secs(5), + gc_lifetime: Duration::from_secs(30), + skip_cert_verify: true, + alpn: vec!["h3".to_string()], + reconnect, + }; + let client = Arc::new(TuicOutbound::new(ctx, opts).await?); + let poll_client = client.clone(); tokio::spawn( async move { - let _ = server.listen(callback.as_ref()).await; + let _ = poll_client.start_poll().await; } .in_current_span(), ); - - // Allow the server time to bind and begin accepting - tokio::time::sleep(Duration::from_millis(300)).await; - - Ok(TuicTestSetup { server_addr, uuid, ctx }) + tokio::time::sleep(Duration::from_millis(150)).await; + Ok(client) } async fn connect_tuic_client(setup: &TuicTestSetup) -> eyre::Result> { @@ -210,6 +255,7 @@ mod tests { gc_lifetime: Duration::from_secs(30), skip_cert_verify: true, alpn: vec!["h3".to_string()], + reconnect: ReconnectConfig::default(), }; let client: std::sync::Arc = std::sync::Arc::new(TuicOutbound::new(ctx.clone(), opts).await?); let poll_client = client.clone(); @@ -257,6 +303,7 @@ mod tests { gc_lifetime: Duration::from_secs(30), skip_cert_verify: true, alpn: vec!["h3".to_string()], + reconnect: ReconnectConfig::default(), }; let result: eyre::Result = TuicOutbound::new(ctx, opts).await; assert!( @@ -283,6 +330,7 @@ mod tests { gc_lifetime: Duration::from_secs(30), skip_cert_verify: true, alpn: vec!["h3".to_string()], + reconnect: ReconnectConfig::default(), }; let result: eyre::Result = TuicOutbound::new(ctx, opts).await; assert!( @@ -513,4 +561,247 @@ mod tests { Ok(()) }); } + + /// Graceful shutdown — idle server. Cancelling the context token must make + /// the inbound `listen` accept-loop break, close the QUIC endpoint, and + /// return `Ok(())` within a bounded time rather than hanging. + #[tokio::test] + async fn test_graceful_shutdown_idle_listen_loop_exits() { + let (ctx, _addr, _uuid, listen) = spawn_tuic_server().await.expect("Failed to start TUIC server"); + + ctx.token.cancel(); + + let joined = tokio::time::timeout(Duration::from_secs(5), listen) + .await + .expect("listen loop did not exit within 5s of cancellation") + .expect("listen task panicked"); + assert!(joined.is_ok(), "listen returned an error on shutdown: {:?}", joined.err()); + } + + /// Graceful shutdown — active connection. With a live client connected, the + /// server's per-connection handler is spawned into `ctx.tasks`. After + /// cancellation, the whole cancellation chain (listen loop → + /// `serve_connection` → acceptor tasks) must wind down so the tracked + /// tasks drain and the accept loop returns — all within a bounded time. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_graceful_shutdown_drains_active_connection() { + let (ctx, addr, uuid, listen) = spawn_tuic_server().await.expect("Failed to start TUIC server"); + + // Establish a connection so the server spawns a handler into ctx.tasks. + let _client = connect_client(addr, uuid).await.expect("Failed to connect TUIC client"); + // Give the server's accept loop a moment to register the connection + // handler in the tracker. + tokio::time::sleep(Duration::from_millis(200)).await; + + // Trigger graceful shutdown. + ctx.token.cancel(); + + // The connection handler(s) tracked in ctx.tasks must finish promptly — + // before the cancellation-chain fix, the acceptor tasks would keep + // `serve_connection` alive and this would hang. + ctx.tasks.close(); + tokio::time::timeout(Duration::from_secs(5), ctx.tasks.wait()) + .await + .expect("tracked connection tasks did not drain within 5s of cancellation"); + + // And the accept loop itself exits cleanly. + let joined = tokio::time::timeout(Duration::from_secs(5), listen) + .await + .expect("listen loop did not exit within 5s of cancellation") + .expect("listen task panicked"); + assert!(joined.is_ok(), "listen returned an error on shutdown: {:?}", joined.err()); + } + + /// Spawn a TUIC relay server bound to a fixed address with a known user; + /// returns its context (to cancel) and the listen-loop join handle. Lets + /// the reconnect test restart a server on the same address. + async fn spawn_server_on(addr: SocketAddr, uuid: Uuid) -> (Arc, tokio::task::JoinHandle>) { + let (cert, key) = generate_tuic_test_cert(); + let mut users = HashMap::new(); + users.insert(uuid, String::from_utf8_lossy(TEST_PASSWORD).to_string()); + + let ctx = Arc::new(AppContext::default()); + let opts = TuicInboundOpts { + listen_addr: addr, + certificate: cert, + private_key: key, + alpn: vec!["h3".to_string()], + users, + auth_timeout: Duration::from_secs(5), + max_idle_time: Duration::from_secs(30), + zero_rtt: false, + ..Default::default() + }; + let server = TuicInbound::new(ctx.clone(), opts); + let callback = Arc::new(DirectCallback); + let handle = tokio::spawn(async move { server.listen(callback.as_ref()).await }.in_current_span()); + tokio::time::sleep(Duration::from_millis(300)).await; + (ctx, handle) + } + + /// Attempt one proxied TCP echo through the client. Returns the echoed + /// bytes, or an error if the relay/connection is currently unavailable + /// (e.g. mid reconnect). + async fn proxy_echo_once(client: &Arc, echo_port: u16, msg: &[u8]) -> eyre::Result> { + let (mut local, remote) = tokio::io::duplex(4096); + let target = TargetAddr::IPv4(std::net::Ipv4Addr::LOCALHOST, echo_port); + let c = client.clone(); + tokio::spawn(async move { + let _ = c.handle_tcp(target, remote, None::).await; + }); + local.write_all(msg).await?; + let mut buf = vec![0u8; msg.len()]; + tokio::time::timeout(Duration::from_secs(2), local.read_exact(&mut buf)).await??; + Ok(buf) + } + + /// The client must transparently reconnect after the relay server restarts + /// on the same address: a new proxied request succeeds over the fresh + /// connection even though the original connection (and its streams) were + /// torn down. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_client_reconnects_after_server_restart() { + use tokio::net::TcpListener; + + // Persistent echo target. + let echo = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let echo_port = echo.local_addr().unwrap().port(); + tokio::spawn(async move { + while let Ok((mut s, _)) = echo.accept().await { + tokio::spawn(async move { + let (mut r, mut w) = s.split(); + let _ = tokio::io::copy(&mut r, &mut w).await; + }); + } + }); + + // Reserve a fixed UDP port for the relay so the second server can rebind it. + let probe = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let server_addr = probe.local_addr().unwrap(); + drop(probe); + + let uuid = Uuid::new_v4(); + + // Server #1 + client; confirm the proxy works. + let (ctx1, handle1) = spawn_server_on(server_addr, uuid).await; + let client = connect_client(server_addr, uuid).await.expect("connect client"); + let got = proxy_echo_once(&client, echo_port, b"before") + .await + .expect("initial proxied echo must succeed"); + assert_eq!(got, b"before"); + + // Drop server #1 and wait for the listen loop to release the UDP port. + ctx1.token.cancel(); + let _ = tokio::time::timeout(Duration::from_secs(5), handle1).await; + tokio::time::sleep(Duration::from_millis(300)).await; + + // Server #2 on the SAME address with the same credentials. + let (ctx2, _handle2) = spawn_server_on(server_addr, uuid).await; + + // The supervisor should reconnect within a few backoff cycles; retry the + // proxied echo until it succeeds (or give up after ~15s). + let mut reconnected = false; + for _ in 0..60 { + if let Ok(got) = proxy_echo_once(&client, echo_port, b"after").await + && got == b"after" + { + reconnected = true; + break; + } + tokio::time::sleep(Duration::from_millis(250)).await; + } + assert!(reconnected, "client did not reconnect and proxy after server restart"); + + ctx2.token.cancel(); + } + + /// Shutting the client down while its supervisor is stuck in the reconnect + /// backoff loop (server still down) must abandon reconnect and drain the + /// tracked tasks promptly, not hang. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_client_shuts_down_cleanly_while_reconnecting() { + let probe = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let server_addr = probe.local_addr().unwrap(); + drop(probe); + let uuid = Uuid::new_v4(); + + let (ctx1, handle1) = spawn_server_on(server_addr, uuid).await; + let client = connect_client(server_addr, uuid).await.expect("connect client"); + + // Kill the server so the supervisor enters its reconnect backoff loop. + ctx1.token.cancel(); + let _ = tokio::time::timeout(Duration::from_secs(5), handle1).await; + // Let the supervisor notice the drop and start retrying. + tokio::time::sleep(Duration::from_millis(500)).await; + + // Now shut the client down. The supervisor (sharing `client.ctx`) must + // stop reconnecting and its tracked tasks must drain. + client.ctx.token.cancel(); + client.ctx.tasks.close(); + tokio::time::timeout(Duration::from_secs(5), client.ctx.tasks.wait()) + .await + .expect("client tasks did not drain within 5s while reconnecting"); + } + + /// With reconnect disabled, a dropped connection is NOT re-established: + /// after the server restarts on the same address, proxied requests keep + /// failing. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_client_does_not_reconnect_when_disabled() { + use tokio::net::TcpListener; + + let echo = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let echo_port = echo.local_addr().unwrap().port(); + tokio::spawn(async move { + while let Ok((mut s, _)) = echo.accept().await { + tokio::spawn(async move { + let (mut r, mut w) = s.split(); + let _ = tokio::io::copy(&mut r, &mut w).await; + }); + } + }); + + let probe = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let server_addr = probe.local_addr().unwrap(); + drop(probe); + let uuid = Uuid::new_v4(); + + let (ctx1, handle1) = spawn_server_on(server_addr, uuid).await; + let client = connect_client_with( + server_addr, + uuid, + ReconnectConfig { + enabled: false, + ..Default::default() + }, + ) + .await + .expect("connect client"); + + // Works while the original connection is alive. + let got = proxy_echo_once(&client, echo_port, b"before") + .await + .expect("initial proxied echo must succeed"); + assert_eq!(got, b"before"); + + // Drop server #1, restart on the same address. + ctx1.token.cancel(); + let _ = tokio::time::timeout(Duration::from_secs(5), handle1).await; + tokio::time::sleep(Duration::from_millis(300)).await; + let (ctx2, _handle2) = spawn_server_on(server_addr, uuid).await; + + // Reconnect is disabled, so no proxied echo should ever succeed even + // though a fresh server is now listening. + let mut recovered = false; + for _ in 0..16 { + if proxy_echo_once(&client, echo_port, b"after").await.is_ok() { + recovered = true; + break; + } + tokio::time::sleep(Duration::from_millis(250)).await; + } + assert!(!recovered, "client reconnected despite reconnect being disabled"); + + ctx2.token.cancel(); + } } diff --git a/crates/wind-tuic/src/quinn/outbound.rs b/crates/wind-tuic/src/quinn/outbound.rs index 8c95fd9..ac32c89 100644 --- a/crates/wind-tuic/src/quinn/outbound.rs +++ b/crates/wind-tuic/src/quinn/outbound.rs @@ -4,6 +4,7 @@ use std::{ time::Duration, }; +use arc_swap::ArcSwap; use moka::future::Cache; use quinn::TokioRuntime; use quinn_congestions::bbr::BbrConfig; @@ -30,6 +31,33 @@ pub struct TuicOutboundOpts { pub gc_lifetime: Duration, pub skip_cert_verify: bool, pub alpn: Vec, + /// Automatic reconnect behaviour for the outbound connection. + pub reconnect: ReconnectConfig, +} + +/// Controls how the outbound supervisor re-establishes the QUIC connection +/// after it drops. +#[derive(Clone, Debug)] +pub struct ReconnectConfig { + /// When `false`, a dropped connection is not re-established — the + /// supervisor closes it and exits (the pre-reconnect behaviour). When + /// `true`, it retries with exponential backoff until it succeeds or the + /// client shuts down. + pub enabled: bool, + /// Delay before the first reconnect attempt; doubled after each failure. + pub initial_backoff: Duration, + /// Upper bound on the backoff delay. + pub max_backoff: Duration, +} + +impl Default for ReconnectConfig { + fn default() -> Self { + Self { + enabled: true, + initial_backoff: Duration::from_millis(500), + max_backoff: Duration::from_secs(30), + } + } } pub struct TuicOutbound { @@ -38,7 +66,10 @@ pub struct TuicOutbound { pub peer_addr: SocketAddr, pub sni: String, pub opts: TuicOutboundOpts, - pub connection: QuinnConnection, + /// The live QUIC connection, swappable so the reconnect supervisor can + /// replace it after a drop without callers holding a stale handle. Read + /// sites `load_full()` the current connection per operation. + pub connection: Arc>, pub udp_assoc_counter: AtomicU16, pub token: CancellationToken, pub udp_session: Cache>>, @@ -84,15 +115,10 @@ impl TuicOutbound { let endpoint = quinn::Endpoint::new(quinn::EndpointConfig::default(), None, socket, Arc::new(TokioRuntime))?; endpoint.set_default_client_config(client_config); - let raw_conn = endpoint - .connect(peer_addr, &server_name) - .map_err(|e| eyre::eyre!("Failed to connect to {} ({}): {}", peer_addr, server_name, e))? - .await?; - // Wrap in the backend-agnostic handle so the shared client/proto code - // (auth, heartbeat, TCP/UDP relay) drives it. - let connection = QuinnConnection::new(raw_conn); - connection.send_auth(&opts.auth.0, &opts.auth.1).await?; + // Establish and authenticate the initial connection. The reconnect + // supervisor reuses the same configured endpoint via `connect_and_auth`. + let connection = connect_and_auth(&endpoint, peer_addr, &server_name, &opts.auth).await?; Ok(Self { token: ctx.token.child_token(), @@ -101,152 +127,289 @@ impl TuicOutbound { peer_addr, sni: server_name, opts, - connection, + connection: Arc::new(ArcSwap::from_pointee(connection)), udp_assoc_counter: AtomicU16::new(0), udp_session: Cache::new(u16::MAX.into()), }) } pub async fn start_poll(&self) -> eyre::Result<()> { - let cancel_token = self.ctx.token.child_token(); - let connection = self.connection.clone(); + let shutdown = self.ctx.token.child_token(); + let connection_cell = self.connection.clone(); + let endpoint = self.endpoint.clone(); + let peer_addr = self.peer_addr; + let sni = self.sni.clone(); + let auth = self.opts.auth.clone(); + let heartbeat = self.opts.heartbeat; + let reconnect = self.opts.reconnect.clone(); + let ctx = self.ctx.clone(); let udp_session = self.udp_session.clone(); - let mut hb_interval = tokio::time::interval(self.opts.heartbeat); - const HEARTBEAT_MAX_FAILURES: usize = 3; - - let (datagram_rx, bi_rx, uni_rx) = self - .connection - .handle_incoming(self.ctx.clone(), cancel_token.clone()) - .await?; + // Connection supervisor: run one session (heartbeat + incoming handling) + // over the live connection until it drops or we shut down. On an + // unexpected drop, reconnect with backoff and swap the fresh connection + // into `connection_cell` so `handle_tcp` / `handle_udp` transparently use + // it. In-flight streams on the old connection are NOT resurrected — + // callers see them close and retry, getting new streams on the new + // connection. + let supervisor = async move { + loop { + let session_cancel = shutdown.child_token(); + let conn = connection_cell.load_full().as_ref().clone(); - let poll_task = async move { - let mut hb_failures = 0; - hb_interval.tick().await; + let end = run_session(&ctx, &conn, &udp_session, heartbeat, session_cancel.clone(), shutdown.clone()).await; + // Tear down this session's accept loops before reconnecting. + session_cancel.cancel(); - loop { - tokio::select! { - _ = cancel_token.cancelled() => { - info!(target: "tuic_out", "Heartbeat poll cancelled"); + match end { + Ok(SessionEnd::Shutdown) => { // Tell the server we are going away so it can reap the // connection immediately instead of waiting out its idle // timeout. - connection.close(0, b"client shutdown"); - return Ok(()); + conn.close(0, b"client shutdown"); + return eyre::Ok(()); } - _ = hb_interval.tick() => { - if let Err(e) = connection.send_heartbeat().await { - hb_failures += 1; - info!(target: "tuic_out", "Heartbeat failed ({}/{}): {}", hb_failures, HEARTBEAT_MAX_FAILURES, e); - - if hb_failures >= HEARTBEAT_MAX_FAILURES { - return Err(eyre::eyre!("Too many heartbeat failures ({}/{})", hb_failures, HEARTBEAT_MAX_FAILURES)); - } - } else if hb_failures > 0 { - info!(target: "tuic_out", "Heartbeat succeeded after {} failures", hb_failures); - hb_failures = 0; - } + Ok(SessionEnd::Lost) => { + warn!(target: "tuic_out", "Connection to {} lost; attempting to reconnect", peer_addr); } - Ok(_) = bi_rx.recv() => { - warn!(target: "tuic_out", "Received bi-directional stream on Outbound"); + Err(e) => { + warn!(target: "tuic_out", "Session ended with error: {e:?}; attempting to reconnect"); } - Ok(mut buf) = datagram_rx.recv() => { - info!(target: "tuic_out", "Received datagram: {} bytes", buf.len()); - use bytes::Buf; + } - let header = match crate::proto::decode_header(&mut buf, "datagram") { - Ok(h) => h, - Err(e) => { - warn!(target: "tuic_out", "Failed to decode header: {}", e); - continue; - } - }; + // A shutdown that raced the session drop: don't reconnect. + if shutdown.is_cancelled() { + conn.close(0, b"client shutdown"); + return eyre::Ok(()); + } + // Reconnect disabled: close and stop, mirroring the pre-reconnect + // behaviour where a dropped connection ended the poll task. + if !reconnect.enabled { + warn!(target: "tuic_out", "Reconnect disabled; connection to {} will not be re-established", peer_addr); + conn.close(0, b"client connection lost"); + return eyre::Ok(()); + } + // Abandon the dead connection explicitly so the server reaps it. + conn.close(0, b"reconnecting"); - let cmd = match crate::proto::decode_command(header.command, &mut buf, "datagram") { - Ok(c) => c, - Err(e) => { - warn!(target: "tuic_out", "Failed to decode command: {}", e); - continue; - } - }; + match reconnect_loop(&endpoint, peer_addr, &sni, &auth, &reconnect, &shutdown).await { + Some(new_conn) => { + connection_cell.store(Arc::new(new_conn)); + info!(target: "tuic_out", "Reconnected to {}", peer_addr); + } + // Cancelled while backing off. + None => return eyre::Ok(()), + } + } + }; + self.ctx.tasks.spawn(supervisor.in_current_span()); - if let crate::proto::Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - size, - } = cmd { - let addr = match crate::proto::decode_address(&mut buf, "UDP packet") { - Ok(a) => a, - Err(e) => { - warn!(target: "tuic_out", "Failed to decode address: {}", e); - continue; - } - }; - - // Extract payload. `size` is attacker-controlled (it comes straight - // from the wire); `copy_to_bytes` panics when `size > buf.remaining()`, - // so a malicious peer could crash the outbound poll task by - // over-declaring it. Validate first and bail out cleanly instead. - let size = size as usize; - if buf.remaining() < size { - warn!( - target: "tuic_out", - "Packet command claims {} bytes of payload but only {} remain — dropping", - size, buf.remaining() - ); - continue; - } - let payload = buf.copy_to_bytes(size); - - let (target, has_address) = match crate::proto::address_to_target(addr) { - Ok(t) => (t, true), - Err(_) => { - (TargetAddr::IPv4(std::net::Ipv4Addr::UNSPECIFIED, 0), false) - } - }; - - if has_address { - info!(target: "tuic_out", "Received UDP packet: assoc={:#06x}, pkt={}, frag={}/{}, size={}, target={}", - assoc_id, pkt_id, frag_id + 1, frag_total, size, target); - } else { - info!(target: "tuic_out", "Received UDP fragment: assoc={:#06x}, pkt={}, frag={}/{}, size={} (no address - non-first fragment)", - assoc_id, pkt_id, frag_id + 1, frag_total, size); - } + Ok(()) + } +} - if let Some(tuic_udp_stream) = udp_session.get(&assoc_id).await { - let complete_packet = if frag_total > 1 { - tuic_udp_stream.process_fragment(assoc_id, pkt_id, frag_total, frag_id, payload, None, target).await - } else { - Some(wind_core::udp::UdpPacket { - source: None, - target, - payload, - }) - }; - - if let Some(packet) = complete_packet - && let Err(e) = tuic_udp_stream.receive_packet(packet).await { - warn!(target: "tuic_out", "Failed to send packet to UDP session {:#06x}: {}", assoc_id, e); - } - } else { - warn!(target: "tuic_out", "Received UDP packet for unknown association {:#06x}", assoc_id); - } - } else { - warn!(target: "tuic_out", "Received non-Packet command in datagram: {:?}", cmd); +/// Outcome of a single connection session. +enum SessionEnd { + /// The shutdown token fired — the supervisor should stop, not reconnect. + Shutdown, + /// The connection dropped (peer/transport close or heartbeat failures) — + /// the supervisor should reconnect. + Lost, +} + +/// Open a fresh QUIC connection on `endpoint` and complete the TUIC auth +/// handshake. Shared by initial connect ([`TuicOutbound::new`]) and reconnect. +async fn connect_and_auth( + endpoint: &quinn::Endpoint, + peer_addr: SocketAddr, + sni: &str, + auth: &(Uuid, Arc<[u8]>), +) -> Result { + let raw = endpoint + .connect(peer_addr, sni) + .map_err(|e| eyre::eyre!("Failed to connect to {} ({}): {}", peer_addr, sni, e))? + .await?; + // Wrap in the backend-agnostic handle so the shared client/proto code + // (auth, heartbeat, TCP/UDP relay) drives it. + let connection = QuinnConnection::new(raw); + connection.send_auth(&auth.0, &auth.1).await?; + Ok(connection) +} + +/// Next exponential-backoff delay: double `current`, capped at `max`. +fn next_backoff(current: Duration, max: Duration) -> Duration { + current.saturating_mul(2).min(max) +} + +/// Retry [`connect_and_auth`] with exponential backoff until it succeeds or +/// `shutdown` fires. Returns `None` if cancelled before a connection is made. +async fn reconnect_loop( + endpoint: &quinn::Endpoint, + peer_addr: SocketAddr, + sni: &str, + auth: &(Uuid, Arc<[u8]>), + reconnect: &ReconnectConfig, + shutdown: &CancellationToken, +) -> Option { + let mut backoff = reconnect.initial_backoff; + + loop { + // Race the connect attempt against shutdown so a hung handshake (e.g. + // the server is still down) doesn't delay a graceful exit. + let attempt = tokio::select! { + _ = shutdown.cancelled() => return None, + r = connect_and_auth(endpoint, peer_addr, sni, auth) => r, + }; + match attempt { + Ok(conn) => return Some(conn), + Err(e) => { + warn!(target: "tuic_out", "Reconnect to {} failed: {e}; retrying in {:?}", peer_addr, backoff); + tokio::select! { + _ = shutdown.cancelled() => return None, + _ = tokio::time::sleep(backoff) => {} + } + backoff = next_backoff(backoff, reconnect.max_backoff); + } + } + } +} + +/// Drive one connection's heartbeat and incoming-stream handling until the +/// connection drops, heartbeats fail repeatedly, or shutdown fires. `conn` is +/// the live connection; `session_cancel` scopes this session's accept loops. +async fn run_session( + ctx: &Arc, + conn: &QuinnConnection, + udp_session: &Cache>>, + heartbeat: Duration, + session_cancel: CancellationToken, + shutdown: CancellationToken, +) -> eyre::Result { + let (datagram_rx, bi_rx, uni_rx) = conn.handle_incoming(ctx.clone(), session_cancel).await?; + + let mut hb_interval = tokio::time::interval(heartbeat); + const HEARTBEAT_MAX_FAILURES: usize = 3; + let mut hb_failures = 0; + hb_interval.tick().await; + + loop { + tokio::select! { + _ = shutdown.cancelled() => { + info!(target: "tuic_out", "Heartbeat poll cancelled"); + return Ok(SessionEnd::Shutdown); + } + _ = conn.closed() => { + info!(target: "tuic_out", "Connection closed"); + return Ok(SessionEnd::Lost); + } + _ = hb_interval.tick() => { + if let Err(e) = conn.send_heartbeat().await { + hb_failures += 1; + info!(target: "tuic_out", "Heartbeat failed ({}/{}): {}", hb_failures, HEARTBEAT_MAX_FAILURES, e); + + if hb_failures >= HEARTBEAT_MAX_FAILURES { + return Ok(SessionEnd::Lost); + } + } else if hb_failures > 0 { + info!(target: "tuic_out", "Heartbeat succeeded after {} failures", hb_failures); + hb_failures = 0; + } + } + Ok(_) = bi_rx.recv() => { + warn!(target: "tuic_out", "Received bi-directional stream on Outbound"); + } + Ok(mut buf) = datagram_rx.recv() => { + info!(target: "tuic_out", "Received datagram: {} bytes", buf.len()); + use bytes::Buf; + + let header = match crate::proto::decode_header(&mut buf, "datagram") { + Ok(h) => h, + Err(e) => { + warn!(target: "tuic_out", "Failed to decode header: {}", e); + continue; + } + }; + + let cmd = match crate::proto::decode_command(header.command, &mut buf, "datagram") { + Ok(c) => c, + Err(e) => { + warn!(target: "tuic_out", "Failed to decode command: {}", e); + continue; + } + }; + + if let crate::proto::Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + size, + } = cmd { + let addr = match crate::proto::decode_address(&mut buf, "UDP packet") { + Ok(a) => a, + Err(e) => { + warn!(target: "tuic_out", "Failed to decode address: {}", e); + continue; } + }; + + // Extract payload. `size` is attacker-controlled (it comes straight + // from the wire); `copy_to_bytes` panics when `size > buf.remaining()`, + // so a malicious peer could crash the outbound poll task by + // over-declaring it. Validate first and bail out cleanly instead. + let size = size as usize; + if buf.remaining() < size { + warn!( + target: "tuic_out", + "Packet command claims {} bytes of payload but only {} remain — dropping", + size, buf.remaining() + ); + continue; } + let payload = buf.copy_to_bytes(size); + + let (target, has_address) = match crate::proto::address_to_target(addr) { + Ok(t) => (t, true), + Err(_) => { + (TargetAddr::IPv4(std::net::Ipv4Addr::UNSPECIFIED, 0), false) + } + }; + + if has_address { + info!(target: "tuic_out", "Received UDP packet: assoc={:#06x}, pkt={}, frag={}/{}, size={}, target={}", + assoc_id, pkt_id, frag_id + 1, frag_total, size, target); + } else { + info!(target: "tuic_out", "Received UDP fragment: assoc={:#06x}, pkt={}, frag={}/{}, size={} (no address - non-first fragment)", + assoc_id, pkt_id, frag_id + 1, frag_total, size); + } + + if let Some(tuic_udp_stream) = udp_session.get(&assoc_id).await { + let complete_packet = if frag_total > 1 { + tuic_udp_stream.process_fragment(assoc_id, pkt_id, frag_total, frag_id, payload, None, target).await + } else { + Some(wind_core::udp::UdpPacket { + source: None, + target, + payload, + }) + }; - Ok(_recv) = uni_rx.recv() => { - info!(target: "tuic_out", "Received uni-directional stream"); + if let Some(packet) = complete_packet + && let Err(e) = tuic_udp_stream.receive_packet(packet).await { + warn!(target: "tuic_out", "Failed to send packet to UDP session {:#06x}: {}", assoc_id, e); + } + } else { + warn!(target: "tuic_out", "Received UDP packet for unknown association {:#06x}", assoc_id); } + } else { + warn!(target: "tuic_out", "Received non-Packet command in datagram: {:?}", cmd); } } - }; - self.ctx.tasks.spawn(poll_task.in_current_span()); - Ok(()) + Ok(_recv) = uni_rx.recv() => { + info!(target: "tuic_out", "Received uni-directional stream"); + } + } } } @@ -259,7 +422,8 @@ impl AbstractOutbound for TuicOutbound { stream: impl AbstractTcpStream, _dialer: Option, ) -> eyre::Result<()> { - self.connection.open_tcp(&target_addr, stream).await?; + let connection = self.connection.load_full(); + connection.open_tcp(&target_addr, stream).await?; Ok(()) } @@ -294,7 +458,9 @@ impl AbstractOutbound for TuicOutbound { }; info!(target: "tuic_out", "Creating new UDP association: {:#06x}", assoc_id); - let connection = self.connection.clone(); + // Snapshot the live connection for this association. If a reconnect swaps + // the connection later, this session's streams die and the caller retries. + let connection = self.connection.load_full().as_ref().clone(); let (receive_tx, receive_rx) = crossfire::mpmc::bounded_async(256); let tuic_stream = Arc::new(crate::proto::UdpStream::new(connection.clone(), assoc_id, receive_tx)); self.udp_session.insert(assoc_id, tuic_stream.clone()).await; @@ -367,10 +533,46 @@ impl AbstractOutbound for TuicOutbound { } } - if let Err(err) = self.connection.drop_udp(assoc_id).await { + let connection = self.connection.load_full(); + if let Err(err) = connection.drop_udp(assoc_id).await { info!(target: "tuic_out", "Error dropping UDP association {:#06x}: {}", assoc_id, err); } Ok(()) } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::{ReconnectConfig, next_backoff}; + + #[test] + fn next_backoff_doubles_until_capped() { + let max = Duration::from_secs(30); + assert_eq!(next_backoff(Duration::from_millis(500), max), Duration::from_secs(1)); + assert_eq!(next_backoff(Duration::from_secs(1), max), Duration::from_secs(2)); + assert_eq!(next_backoff(Duration::from_secs(16), max), max); + // Already at/over the cap stays capped. + assert_eq!(next_backoff(max, max), max); + assert_eq!(next_backoff(Duration::from_secs(60), max), max); + } + + #[test] + fn next_backoff_does_not_overflow() { + // Doubling near Duration::MAX must saturate, not panic. + let huge = Duration::from_secs(u64::MAX / 2 + 1); + let max = Duration::from_secs(u64::MAX); + assert_eq!(next_backoff(huge, max), max); + } + + #[test] + fn reconnect_config_default_is_enabled_with_sane_bounds() { + let cfg = ReconnectConfig::default(); + assert!(cfg.enabled); + assert_eq!(cfg.initial_backoff, Duration::from_millis(500)); + assert_eq!(cfg.max_backoff, Duration::from_secs(30)); + assert!(cfg.initial_backoff <= cfg.max_backoff); + } +} diff --git a/crates/wind-tuic/src/quinn/tls.rs b/crates/wind-tuic/src/quinn/tls.rs index 0a3e507..2d9885c 100644 --- a/crates/wind-tuic/src/quinn/tls.rs +++ b/crates/wind-tuic/src/quinn/tls.rs @@ -123,6 +123,7 @@ mod tests { // between both branches. skip_cert_verify: true, alpn, + reconnect: crate::quinn::outbound::ReconnectConfig::default(), } } diff --git a/crates/wind-tuic/src/server/mod.rs b/crates/wind-tuic/src/server/mod.rs index a9d75b9..0da3ed9 100644 --- a/crates/wind-tuic/src/server/mod.rs +++ b/crates/wind-tuic/src/server/mod.rs @@ -935,3 +935,206 @@ async fn handle_dissociate(connection: &InboundCtx, assoc_ info!("Dissociated UDP session {}", assoc_id); Ok(()) } + +#[cfg(test)] +mod tests { + use std::sync::atomic::AtomicUsize; + + // Brings in Arc, Duration, Ordering, CancellationToken, QuicError, CmdType, + // and the private helpers under test (`acceptor_loop`, `is_tuic_prefix`, + // `read_prefix`). + use super::*; + + /// Cancellation must interrupt an accept that is parked forever. This is + /// the core of the graceful-shutdown chain: every per-connection acceptor + /// loop is blocked in `accept()` when shutdown fires, and must unstick + /// promptly without handling another item. + #[tokio::test] + async fn acceptor_loop_exits_when_cancelled_mid_accept() { + let cancel = CancellationToken::new(); + let handled = Arc::new(AtomicUsize::new(0)); + let h = handled.clone(); + let loop_cancel = cancel.clone(); + + let task = tokio::spawn(async move { + acceptor_loop( + loop_cancel, + "test-mid-accept", + // Never resolves: only the cancel branch can complete the loop. + std::future::pending::>, + move |_item: ()| { + let h = h.clone(); + async move { + h.fetch_add(1, Ordering::SeqCst); + } + }, + ) + .await; + }); + + // Let the loop reach its `select!` and park on `accept()`. + tokio::task::yield_now().await; + cancel.cancel(); + + tokio::time::timeout(Duration::from_secs(1), task) + .await + .expect("acceptor_loop did not exit within 1s of cancellation") + .expect("acceptor_loop task panicked"); + + assert_eq!( + handled.load(Ordering::SeqCst), + 0, + "no item should be handled when accept never resolves" + ); + } + + /// A loop that is cancelled before it ever runs must exit without spinning. + #[tokio::test] + async fn acceptor_loop_exits_when_already_cancelled() { + let cancel = CancellationToken::new(); + cancel.cancel(); + let handled = Arc::new(AtomicUsize::new(0)); + let h = handled.clone(); + + tokio::time::timeout( + Duration::from_secs(1), + acceptor_loop( + cancel, + "test-pre-cancelled", + std::future::pending::>, + move |_item: ()| { + let h = h.clone(); + async move { + h.fetch_add(1, Ordering::SeqCst); + } + }, + ), + ) + .await + .expect("acceptor_loop did not exit promptly when pre-cancelled"); + + assert_eq!(handled.load(Ordering::SeqCst), 0); + } + + /// Items accepted before a benign connection close are handled, then the + /// loop returns (it does not treat `LocallyClosed` as a fatal error nor + /// spin). + #[tokio::test] + async fn acceptor_loop_handles_items_then_exits_on_benign_close() { + let cancel = CancellationToken::new(); + let handled = Arc::new(AtomicUsize::new(0)); + let calls = Arc::new(AtomicUsize::new(0)); + let h = handled.clone(); + let c = calls.clone(); + + tokio::time::timeout( + Duration::from_secs(1), + acceptor_loop( + cancel, + "test-benign-close", + move || { + let n = c.fetch_add(1, Ordering::SeqCst); + async move { if n < 3 { Ok(()) } else { Err(QuicError::LocallyClosed) } } + }, + move |_item: ()| { + let h = h.clone(); + async move { + h.fetch_add(1, Ordering::SeqCst); + } + }, + ), + ) + .await + .expect("acceptor_loop did not terminate after a benign close"); + + assert_eq!( + handled.load(Ordering::SeqCst), + 3, + "the three Ok items must be handled before the close ends the loop" + ); + } + + /// `TimedOut` (idle timeout) is a benign lifecycle close: the loop returns. + #[tokio::test] + async fn acceptor_loop_exits_on_timed_out() { + let cancel = CancellationToken::new(); + let handled = Arc::new(AtomicUsize::new(0)); + let h = handled.clone(); + + tokio::time::timeout( + Duration::from_secs(1), + acceptor_loop( + cancel, + "test-timed-out", + || async { Err::<(), _>(QuicError::TimedOut) }, + move |_item: ()| { + let h = h.clone(); + async move { + h.fetch_add(1, Ordering::SeqCst); + } + }, + ), + ) + .await + .expect("acceptor_loop did not terminate on TimedOut"); + + assert_eq!(handled.load(Ordering::SeqCst), 0); + } + + /// A non-benign error (e.g. connection lost) also ends the loop rather than + /// retrying forever. + #[tokio::test] + async fn acceptor_loop_exits_on_fatal_error() { + let cancel = CancellationToken::new(); + let handled = Arc::new(AtomicUsize::new(0)); + let h = handled.clone(); + + tokio::time::timeout( + Duration::from_secs(1), + acceptor_loop( + cancel, + "test-fatal", + || async { Err::<(), _>(QuicError::ConnectionLost("boom".into())) }, + move |_item: ()| { + let h = h.clone(); + async move { + h.fetch_add(1, Ordering::SeqCst); + } + }, + ), + ) + .await + .expect("acceptor_loop did not terminate on a fatal error"); + + assert_eq!(handled.load(Ordering::SeqCst), 0); + } + + /// The 2-byte classifier must accept only `[VER, CmdType]` framing and + /// reject anything an HTTP/3 stream would start with. + #[test] + fn is_tuic_prefix_distinguishes_tuic_from_h3() { + let auth = u8::from(CmdType::Auth); + let heartbeat = u8::from(CmdType::Heartbeat); + + assert!(is_tuic_prefix([crate::proto::VER, auth])); + assert!(is_tuic_prefix([crate::proto::VER, heartbeat])); + // CmdType byte just past the valid range (Auth..=Heartbeat). + assert!(!is_tuic_prefix([crate::proto::VER, heartbeat + 1])); + // Correct command byte but wrong version byte. + assert!(!is_tuic_prefix([crate::proto::VER.wrapping_add(1), auth])); + } + + /// `read_prefix` yields the first two bytes, or `None` if the stream closes + /// before two bytes arrive. + #[tokio::test] + async fn read_prefix_returns_two_bytes_or_none() { + let mut full: &[u8] = &[0x05, 0x00, 0x42]; + assert_eq!(read_prefix(&mut full).await, Some([0x05, 0x00])); + + let mut short: &[u8] = &[0x05]; + assert_eq!(read_prefix(&mut short).await, None); + + let mut empty: &[u8] = &[]; + assert_eq!(read_prefix(&mut empty).await, None); + } +} diff --git a/crates/wind/src/conf/runtime.rs b/crates/wind/src/conf/runtime.rs index 9283da3..9d75526 100644 --- a/crates/wind/src/conf/runtime.rs +++ b/crates/wind/src/conf/runtime.rs @@ -64,6 +64,7 @@ impl OutboundRuntime { gc_lifetime: std::time::Duration::from_secs(t.gc_lifetime_secs), skip_cert_verify: t.skip_cert_verify, alpn: t.alpn.clone(), + reconnect: wind_tuic::quinn::outbound::ReconnectConfig::default(), }), } } From 7b0086e492ab7afdd88741499db8ee65c6c7a5e9 Mon Sep 17 00:00:00 2001 From: iHsin Date: Mon, 15 Jun 2026 21:33:17 +0800 Subject: [PATCH 2/2] test(socks): reword doc comment to satisfy typos check 'CONNECTed' tokenizes as 'CONNEC'+'Ted'; the typos CI flagged 'CONNEC'. Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/wind-socks/tests/socks_tcp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/wind-socks/tests/socks_tcp.rs b/crates/wind-socks/tests/socks_tcp.rs index 767de73..65b08f4 100644 --- a/crates/wind-socks/tests/socks_tcp.rs +++ b/crates/wind-socks/tests/socks_tcp.rs @@ -140,7 +140,7 @@ fn domain_body(host: &str, port: u16) -> Vec { b } -/// Assert a CONNECTed stream echoes a payload back unchanged. +/// Assert a connected stream echoes a payload back unchanged. async fn assert_echo_roundtrip(s: &mut TcpStream) { let msg = b"hello socks5"; s.write_all(msg).await.unwrap();