Compare commits

...

26 Commits
3.4.4 ... 3.4.7

Author SHA1 Message Date
Alexey
4e57cee9b9 Merge pull request #745 from telemt/flow
API PATCH fixes + No IP tracking with disabled unique-IP limits + Bound hot-path pressure in ME Relay and Handshake + Bounded ME Route fairness and IP-Cleanup-Backlog + Bound relay queues by bytes
2026-04-25 14:45:34 +03:00
Alexey
e217371dc8 Bump 2026-04-25 14:36:51 +03:00
Alexey
37c916056a Rustfmt 2026-04-25 14:35:35 +03:00
Alexey
2f2fe9d5d3 Bound relay queues by bytes
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
Signed-off-by: Alexey <247128645+axkurcom@users.noreply.github.com>
2026-04-25 13:54:20 +03:00
Alexey
1df668144c Bounded ME Route fairness and IP-Cleanup-Backlog
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
Signed-off-by: Alexey <247128645+axkurcom@users.noreply.github.com>
2026-04-25 13:09:10 +03:00
Alexey
8494429690 Merge pull request #743 from amirotin/api/patch-user-null-removal
feat(api): support null-removal in PATCH /v1/users/{user}
2026-04-25 13:07:13 +03:00
Alexey
f25bb17b86 Merge branch 'flow' into api/patch-user-null-removal 2026-04-25 12:28:48 +03:00
Alexey
27b5d576c0 Bound hot-path pressure in ME Relay + Handshake
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
Signed-off-by: Alexey <247128645+axkurcom@users.noreply.github.com>
2026-04-25 12:16:26 +03:00
Alexey
e78592ef9b Avoid IP tracking when unique-IP limits are disabled and cap beobachten memory
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
Signed-off-by: Alexey <247128645+axkurcom@users.noreply.github.com>
2026-04-25 12:00:46 +03:00
Mirotin Artem
4ed87d1946 feat(api): support null-removal in PATCH /v1/users/{user}
PatchUserRequest now uses Patch<T> for the five removable fields
(user_ad_tag, max_tcp_conns, expiration_rfc3339, data_quota_bytes,
max_unique_ips). Sending JSON null drops the entry from the
corresponding access HashMap; sending 0 is preserved as a literal
limit; omitted fields stay untouched. The handler synchronises the
in-memory ip_tracker on both set and remove of max_unique_ips. A
helper parse_patch_expiration mirrors parse_optional_expiration for
the new three-state field. Runtime semantics are unchanged.
2026-04-25 00:49:34 +03:00
Mirotin Artem
635bea4de4 feat(api): add Patch<T> enum for JSON merge-patch semantics
Introduce a three-state Patch<T> (Unchanged / Remove / Set) and a
serde helper patch_field that distinguishes an omitted JSON field
from an explicit null. Wired up next as the field type for the
removable settings on PATCH /v1/users/{user}.
2026-04-25 00:49:34 +03:00
Alexey
8874396ba5 Merge pull request #739 from telemt/flow-test
Relays Tests Fixes
2026-04-24 15:51:47 +03:00
Alexey
033ebf5038 Relays Tests Fixes 2026-04-24 15:51:19 +03:00
Alexey
f7b918875c Close Errors Classification + TLS 1.2/1.3 Correctness in Fronting + Full ServerHello + ALPN in TLS Fetcher: merge pull request #738 from telemt/flow
Close Errors Classification + TLS 1.2/1.3 Correctness in Fronting + Full ServerHello + ALPN in TLS Fetcher
2026-04-24 15:48:39 +03:00
Alexey
8960fad8cd Сlassified Bad Connections and Handshake Failures in API
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-04-24 10:56:30 +03:00
Alexey
493f5c9680 ALPN in TLS Fetcher
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-04-23 22:22:05 +03:00
Alexey
67357310f7 TLS 1.2/1.3 Correctness + Full ServerHello + Rustfmt
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-04-23 21:29:18 +03:00
Alexey
8684378030 Human-readable Peer Close Classification 2026-04-21 15:46:18 +03:00
Alexey
db8d333ed6 Noisy-network peer Close Errors Classification 2026-04-21 15:35:11 +03:00
Alexey
30e73adaac Bump 2026-04-21 13:38:38 +03:00
Alexey
351f2c8458 Fairness Regression fixes + Unlimited mask_relay_max_bytes: merge pull request #726 from telemt/flow
Fairness Regression fixes + Unlimited mask_relay_max_bytes
2026-04-21 13:37:10 +03:00
Alexey
4ce6b14bd8 Rustfmt 2026-04-21 13:31:24 +03:00
Alexey
db114f09c3 Sync tests with code 2026-04-21 13:30:11 +03:00
Alexey
09310ff284 Unlimited mask_relay_max_bytes 2026-04-21 11:30:58 +03:00
Alexey
1e5b84c0ed Fairshare Disabled semantics fix 2026-04-21 11:21:58 +03:00
Alexey
926e3aa987 Fairness Regression fixes 2026-04-21 01:11:43 +03:00
52 changed files with 2131 additions and 346 deletions

2
Cargo.lock generated
View File

@@ -2791,7 +2791,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]]
name = "telemt"
version = "3.4.4"
version = "3.4.7"
dependencies = [
"aes",
"anyhow",

View File

@@ -1,6 +1,6 @@
[package]
name = "telemt"
version = "3.4.4"
version = "3.4.7"
edition = "2024"
[features]

View File

@@ -28,6 +28,7 @@ mod config_store;
mod events;
mod http_utils;
mod model;
mod patch;
mod runtime_edge;
mod runtime_init;
mod runtime_min;
@@ -41,7 +42,7 @@ use config_store::{current_revision, load_config_from_disk, parse_if_match};
use events::ApiEventStore;
use http_utils::{error_response, read_json, read_optional_json, success_response};
use model::{
ApiFailure, CreateUserRequest, DeleteUserResponse, HealthData, HealthReadyData,
ApiFailure, ClassCount, CreateUserRequest, DeleteUserResponse, HealthData, HealthReadyData,
PatchUserRequest, RotateSecretRequest, SummaryData, UserActiveIps,
};
use runtime_edge::{
@@ -334,10 +335,24 @@ async fn handle(
}
("GET", "/v1/stats/summary") => {
let revision = current_revision(&shared.config_path).await?;
let connections_bad_by_class = shared
.stats
.get_connects_bad_class_counts()
.into_iter()
.map(|(class, total)| ClassCount { class, total })
.collect();
let handshake_failures_by_class = shared
.stats
.get_handshake_failure_class_counts()
.into_iter()
.map(|(class, total)| ClassCount { class, total })
.collect();
let data = SummaryData {
uptime_seconds: shared.stats.uptime_secs(),
connections_total: shared.stats.get_connects_all(),
connections_bad_total: shared.stats.get_connects_bad(),
connections_bad_by_class,
handshake_failures_by_class,
handshake_timeouts_total: shared.stats.get_handshake_timeouts(),
configured_users: cfg.access.users.len(),
};

View File

@@ -5,6 +5,7 @@ use chrono::{DateTime, Utc};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use super::patch::{Patch, patch_field};
use crate::crypto::SecureRandom;
const MAX_USERNAME_LEN: usize = 64;
@@ -71,11 +72,19 @@ pub(super) struct HealthReadyData {
pub(super) total_upstreams: usize,
}
#[derive(Serialize, Clone)]
pub(super) struct ClassCount {
pub(super) class: String,
pub(super) total: u64,
}
#[derive(Serialize)]
pub(super) struct SummaryData {
pub(super) uptime_seconds: f64,
pub(super) connections_total: u64,
pub(super) connections_bad_total: u64,
pub(super) connections_bad_by_class: Vec<ClassCount>,
pub(super) handshake_failures_by_class: Vec<ClassCount>,
pub(super) handshake_timeouts_total: u64,
pub(super) configured_users: usize,
}
@@ -91,6 +100,8 @@ pub(super) struct ZeroCoreData {
pub(super) uptime_seconds: f64,
pub(super) connections_total: u64,
pub(super) connections_bad_total: u64,
pub(super) connections_bad_by_class: Vec<ClassCount>,
pub(super) handshake_failures_by_class: Vec<ClassCount>,
pub(super) handshake_timeouts_total: u64,
pub(super) accept_permit_timeout_total: u64,
pub(super) configured_users: usize,
@@ -497,11 +508,16 @@ pub(super) struct CreateUserRequest {
#[derive(Deserialize)]
pub(super) struct PatchUserRequest {
pub(super) secret: Option<String>,
pub(super) user_ad_tag: Option<String>,
pub(super) max_tcp_conns: Option<usize>,
pub(super) expiration_rfc3339: Option<String>,
pub(super) data_quota_bytes: Option<u64>,
pub(super) max_unique_ips: Option<usize>,
#[serde(default, deserialize_with = "patch_field")]
pub(super) user_ad_tag: Patch<String>,
#[serde(default, deserialize_with = "patch_field")]
pub(super) max_tcp_conns: Patch<usize>,
#[serde(default, deserialize_with = "patch_field")]
pub(super) expiration_rfc3339: Patch<String>,
#[serde(default, deserialize_with = "patch_field")]
pub(super) data_quota_bytes: Patch<u64>,
#[serde(default, deserialize_with = "patch_field")]
pub(super) max_unique_ips: Patch<usize>,
}
#[derive(Default, Deserialize)]
@@ -520,6 +536,20 @@ pub(super) fn parse_optional_expiration(
Ok(Some(parsed.with_timezone(&Utc)))
}
pub(super) fn parse_patch_expiration(
value: &Patch<String>,
) -> Result<Patch<DateTime<Utc>>, ApiFailure> {
match value {
Patch::Unchanged => Ok(Patch::Unchanged),
Patch::Remove => Ok(Patch::Remove),
Patch::Set(raw) => {
let parsed = DateTime::parse_from_rfc3339(raw)
.map_err(|_| ApiFailure::bad_request("expiration_rfc3339 must be valid RFC3339"))?;
Ok(Patch::Set(parsed.with_timezone(&Utc)))
}
}
}
pub(super) fn is_valid_user_secret(secret: &str) -> bool {
secret.len() == 32 && secret.chars().all(|c| c.is_ascii_hexdigit())
}

130
src/api/patch.rs Normal file
View File

@@ -0,0 +1,130 @@
use serde::Deserialize;
/// Three-state field for JSON Merge Patch semantics on the `PATCH /v1/users/{user}`
/// endpoint.
///
/// `Unchanged` is produced when the JSON body omits the field entirely and tells the
/// handler to leave the corresponding configuration entry untouched. `Remove` is
/// produced when the JSON body sets the field to `null` and instructs the handler to
/// drop the entry from the corresponding access HashMap. `Set` carries an explicit
/// new value, including zero, which is preserved verbatim in the configuration.
#[derive(Debug)]
pub(super) enum Patch<T> {
Unchanged,
Remove,
Set(T),
}
impl<T> Default for Patch<T> {
fn default() -> Self {
Self::Unchanged
}
}
/// Serde deserializer adapter for fields that follow JSON Merge Patch semantics.
///
/// Pair this with `#[serde(default, deserialize_with = "patch_field")]` on a
/// `Patch<T>` field. An omitted field falls back to `Patch::Unchanged` via
/// `Default`; an explicit JSON `null` becomes `Patch::Remove`; any other value
/// becomes `Patch::Set(v)`.
pub(super) fn patch_field<'de, D, T>(deserializer: D) -> Result<Patch<T>, D::Error>
where
D: serde::Deserializer<'de>,
T: serde::Deserialize<'de>,
{
Option::<T>::deserialize(deserializer).map(|opt| match opt {
Some(value) => Patch::Set(value),
None => Patch::Remove,
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::api::model::{PatchUserRequest, parse_patch_expiration};
use chrono::{TimeZone, Utc};
use serde::Deserialize;
#[derive(Deserialize)]
struct Holder {
#[serde(default, deserialize_with = "patch_field")]
value: Patch<u64>,
}
fn parse(json: &str) -> Holder {
serde_json::from_str(json).expect("valid json")
}
#[test]
fn omitted_field_yields_unchanged() {
let h = parse("{}");
assert!(matches!(h.value, Patch::Unchanged));
}
#[test]
fn explicit_null_yields_remove() {
let h = parse(r#"{"value": null}"#);
assert!(matches!(h.value, Patch::Remove));
}
#[test]
fn explicit_value_yields_set() {
let h = parse(r#"{"value": 42}"#);
assert!(matches!(h.value, Patch::Set(42)));
}
#[test]
fn explicit_zero_yields_set_zero() {
let h = parse(r#"{"value": 0}"#);
assert!(matches!(h.value, Patch::Set(0)));
}
#[test]
fn parse_patch_expiration_passes_unchanged_and_remove_through() {
assert!(matches!(
parse_patch_expiration(&Patch::Unchanged),
Ok(Patch::Unchanged)
));
assert!(matches!(
parse_patch_expiration(&Patch::Remove),
Ok(Patch::Remove)
));
}
#[test]
fn parse_patch_expiration_parses_set_value() {
let parsed =
parse_patch_expiration(&Patch::Set("2030-01-02T03:04:05Z".into())).expect("valid");
match parsed {
Patch::Set(dt) => {
assert_eq!(dt, Utc.with_ymd_and_hms(2030, 1, 2, 3, 4, 5).unwrap());
}
other => panic!("expected Patch::Set, got {:?}", other),
}
}
#[test]
fn parse_patch_expiration_rejects_invalid_set_value() {
assert!(parse_patch_expiration(&Patch::Set("not-a-date".into())).is_err());
}
#[test]
fn patch_user_request_deserializes_mixed_states() {
let raw = r#"{
"secret": "00112233445566778899aabbccddeeff",
"max_tcp_conns": 0,
"max_unique_ips": null,
"data_quota_bytes": 1024
}"#;
let req: PatchUserRequest = serde_json::from_str(raw).expect("valid json");
assert_eq!(
req.secret.as_deref(),
Some("00112233445566778899aabbccddeeff")
);
assert!(matches!(req.max_tcp_conns, Patch::Set(0)));
assert!(matches!(req.max_unique_ips, Patch::Remove));
assert!(matches!(req.data_quota_bytes, Patch::Set(1024)));
assert!(matches!(req.expiration_rfc3339, Patch::Unchanged));
assert!(matches!(req.user_ad_tag, Patch::Unchanged));
}
}

View File

@@ -7,8 +7,8 @@ use crate::transport::upstream::IpPreference;
use super::ApiShared;
use super::model::{
DcEndpointWriters, DcStatus, DcStatusData, MeWriterStatus, MeWritersData, MeWritersSummary,
MinimalAllData, MinimalAllPayload, MinimalDcPathData, MinimalMeRuntimeData,
ClassCount, DcEndpointWriters, DcStatus, DcStatusData, MeWriterStatus, MeWritersData,
MeWritersSummary, MinimalAllData, MinimalAllPayload, MinimalDcPathData, MinimalMeRuntimeData,
MinimalQuarantineData, UpstreamDcStatus, UpstreamStatus, UpstreamSummaryData, UpstreamsData,
ZeroAllData, ZeroCodeCount, ZeroCoreData, ZeroDesyncData, ZeroMiddleProxyData, ZeroPoolData,
ZeroUpstreamData,
@@ -26,6 +26,16 @@ pub(crate) struct MinimalCacheEntry {
pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> ZeroAllData {
let telemetry = stats.telemetry_policy();
let bad_connection_classes = stats
.get_connects_bad_class_counts()
.into_iter()
.map(|(class, total)| ClassCount { class, total })
.collect();
let handshake_failure_classes = stats
.get_handshake_failure_class_counts()
.into_iter()
.map(|(class, total)| ClassCount { class, total })
.collect();
let handshake_error_codes = stats
.get_me_handshake_error_code_counts()
.into_iter()
@@ -38,6 +48,8 @@ pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> Zer
uptime_seconds: stats.uptime_secs(),
connections_total: stats.get_connects_all(),
connections_bad_total: stats.get_connects_bad(),
connections_bad_by_class: bad_connection_classes,
handshake_failures_by_class: handshake_failure_classes,
handshake_timeouts_total: stats.get_handshake_timeouts(),
accept_permit_timeout_total: stats.get_accept_permit_timeout_total(),
configured_users,

View File

@@ -14,8 +14,9 @@ use super::config_store::{
use super::model::{
ApiFailure, CreateUserRequest, CreateUserResponse, PatchUserRequest, RotateSecretRequest,
UserInfo, UserLinks, is_valid_ad_tag, is_valid_user_secret, is_valid_username,
parse_optional_expiration, random_user_secret,
parse_optional_expiration, parse_patch_expiration, random_user_secret,
};
use super::patch::Patch;
pub(super) async fn create_user(
body: CreateUserRequest,
@@ -182,14 +183,14 @@ pub(super) async fn patch_user(
"secret must be exactly 32 hex characters",
));
}
if let Some(ad_tag) = body.user_ad_tag.as_ref()
if let Patch::Set(ad_tag) = &body.user_ad_tag
&& !is_valid_ad_tag(ad_tag)
{
return Err(ApiFailure::bad_request(
"user_ad_tag must be exactly 32 hex characters",
));
}
let expiration = parse_optional_expiration(body.expiration_rfc3339.as_deref())?;
let expiration = parse_patch_expiration(&body.expiration_rfc3339)?;
let _guard = shared.mutation_lock.lock().await;
let mut cfg = load_config_from_disk(&shared.config_path).await?;
ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?;
@@ -205,38 +206,71 @@ pub(super) async fn patch_user(
if let Some(secret) = body.secret {
cfg.access.users.insert(user.to_string(), secret);
}
if let Some(ad_tag) = body.user_ad_tag {
cfg.access.user_ad_tags.insert(user.to_string(), ad_tag);
match body.user_ad_tag {
Patch::Unchanged => {}
Patch::Remove => {
cfg.access.user_ad_tags.remove(user);
}
Patch::Set(ad_tag) => {
cfg.access.user_ad_tags.insert(user.to_string(), ad_tag);
}
}
if let Some(limit) = body.max_tcp_conns {
cfg.access
.user_max_tcp_conns
.insert(user.to_string(), limit);
match body.max_tcp_conns {
Patch::Unchanged => {}
Patch::Remove => {
cfg.access.user_max_tcp_conns.remove(user);
}
Patch::Set(limit) => {
cfg.access
.user_max_tcp_conns
.insert(user.to_string(), limit);
}
}
if let Some(expiration) = expiration {
cfg.access
.user_expirations
.insert(user.to_string(), expiration);
match expiration {
Patch::Unchanged => {}
Patch::Remove => {
cfg.access.user_expirations.remove(user);
}
Patch::Set(expiration) => {
cfg.access
.user_expirations
.insert(user.to_string(), expiration);
}
}
if let Some(quota) = body.data_quota_bytes {
cfg.access.user_data_quota.insert(user.to_string(), quota);
}
let mut updated_limit = None;
if let Some(limit) = body.max_unique_ips {
cfg.access
.user_max_unique_ips
.insert(user.to_string(), limit);
updated_limit = Some(limit);
match body.data_quota_bytes {
Patch::Unchanged => {}
Patch::Remove => {
cfg.access.user_data_quota.remove(user);
}
Patch::Set(quota) => {
cfg.access.user_data_quota.insert(user.to_string(), quota);
}
}
// Capture how the per-user IP limit changed, so the in-memory ip_tracker
// can be synced (set or removed) after the config is persisted.
let max_unique_ips_change = match body.max_unique_ips {
Patch::Unchanged => None,
Patch::Remove => {
cfg.access.user_max_unique_ips.remove(user);
Some(None)
}
Patch::Set(limit) => {
cfg.access
.user_max_unique_ips
.insert(user.to_string(), limit);
Some(Some(limit))
}
};
cfg.validate()
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
let revision = save_config_to_disk(&shared.config_path, &cfg).await?;
drop(_guard);
if let Some(limit) = updated_limit {
shared.ip_tracker.set_user_limit(user, limit).await;
match max_unique_ips_change {
Some(Some(limit)) => shared.ip_tracker.set_user_limit(user, limit).await,
Some(None) => shared.ip_tracker.remove_user_limit(user).await,
None => {}
}
let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips();
let users = users_from_config(

View File

@@ -689,6 +689,7 @@ tls_domain = "{domain}"
mask = true
mask_port = 443
fake_cert_len = 2048
serverhello_compact = false
tls_full_cert_ttl_secs = 90
[access]

View File

@@ -21,6 +21,8 @@ const DEFAULT_ME_ADAPTIVE_FLOOR_MAX_ACTIVE_WRITERS_PER_CORE: u16 = 64;
const DEFAULT_ME_ADAPTIVE_FLOOR_MAX_WARM_WRITERS_PER_CORE: u16 = 64;
const DEFAULT_ME_ADAPTIVE_FLOOR_MAX_ACTIVE_WRITERS_GLOBAL: u32 = 256;
const DEFAULT_ME_ADAPTIVE_FLOOR_MAX_WARM_WRITERS_GLOBAL: u32 = 256;
const DEFAULT_ME_ROUTE_BACKPRESSURE_ENABLED: bool = false;
const DEFAULT_ME_ROUTE_FAIRSHARE_ENABLED: bool = false;
const DEFAULT_ME_WRITER_CMD_CHANNEL_CAPACITY: usize = 4096;
const DEFAULT_ME_ROUTE_CHANNEL_CAPACITY: usize = 768;
const DEFAULT_ME_C2ME_CHANNEL_CAPACITY: usize = 1024;
@@ -529,6 +531,14 @@ pub(crate) fn default_me_route_backpressure_base_timeout_ms() -> u64 {
25
}
pub(crate) fn default_me_route_backpressure_enabled() -> bool {
DEFAULT_ME_ROUTE_BACKPRESSURE_ENABLED
}
pub(crate) fn default_me_route_fairshare_enabled() -> bool {
DEFAULT_ME_ROUTE_FAIRSHARE_ENABLED
}
pub(crate) fn default_me_route_backpressure_high_timeout_ms() -> u64 {
120
}
@@ -565,6 +575,10 @@ pub(crate) fn default_tls_new_session_tickets() -> u8 {
0
}
pub(crate) fn default_serverhello_compact() -> bool {
false
}
pub(crate) fn default_tls_full_cert_ttl_secs() -> u64 {
90
}

View File

@@ -86,6 +86,8 @@ pub struct HotFields {
pub telemetry_user_enabled: bool,
pub telemetry_me_level: MeTelemetryLevel,
pub me_socks_kdf_policy: MeSocksKdfPolicy,
pub me_route_backpressure_enabled: bool,
pub me_route_fairshare_enabled: bool,
pub me_floor_mode: MeFloorMode,
pub me_adaptive_floor_idle_secs: u64,
pub me_adaptive_floor_min_writers_single_endpoint: u8,
@@ -187,6 +189,8 @@ impl HotFields {
telemetry_user_enabled: cfg.general.telemetry.user_enabled,
telemetry_me_level: cfg.general.telemetry.me_level,
me_socks_kdf_policy: cfg.general.me_socks_kdf_policy,
me_route_backpressure_enabled: cfg.general.me_route_backpressure_enabled,
me_route_fairshare_enabled: cfg.general.me_route_fairshare_enabled,
me_floor_mode: cfg.general.me_floor_mode,
me_adaptive_floor_idle_secs: cfg.general.me_adaptive_floor_idle_secs,
me_adaptive_floor_min_writers_single_endpoint: cfg
@@ -529,6 +533,8 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
new.general.me_route_backpressure_high_timeout_ms;
cfg.general.me_route_backpressure_high_watermark_pct =
new.general.me_route_backpressure_high_watermark_pct;
cfg.general.me_route_backpressure_enabled = new.general.me_route_backpressure_enabled;
cfg.general.me_route_fairshare_enabled = new.general.me_route_fairshare_enabled;
cfg.general.me_reader_route_data_wait_ms = new.general.me_reader_route_data_wait_ms;
cfg.general.me_d2c_flush_batch_max_frames = new.general.me_d2c_flush_batch_max_frames;
cfg.general.me_d2c_flush_batch_max_bytes = new.general.me_d2c_flush_batch_max_bytes;
@@ -618,6 +624,7 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.censorship.server_hello_delay_min_ms != new.censorship.server_hello_delay_min_ms
|| old.censorship.server_hello_delay_max_ms != new.censorship.server_hello_delay_max_ms
|| old.censorship.tls_new_session_tickets != new.censorship.tls_new_session_tickets
|| old.censorship.serverhello_compact != new.censorship.serverhello_compact
|| old.censorship.tls_full_cert_ttl_secs != new.censorship.tls_full_cert_ttl_secs
|| old.censorship.alpn_enforce != new.censorship.alpn_enforce
|| old.censorship.mask_proxy_protocol != new.censorship.mask_proxy_protocol
@@ -1053,6 +1060,8 @@ fn log_changes(
!= new_hot.me_route_backpressure_high_timeout_ms
|| old_hot.me_route_backpressure_high_watermark_pct
!= new_hot.me_route_backpressure_high_watermark_pct
|| old_hot.me_route_backpressure_enabled != new_hot.me_route_backpressure_enabled
|| old_hot.me_route_fairshare_enabled != new_hot.me_route_fairshare_enabled
|| old_hot.me_reader_route_data_wait_ms != new_hot.me_reader_route_data_wait_ms
|| old_hot.me_health_interval_ms_unhealthy != new_hot.me_health_interval_ms_unhealthy
|| old_hot.me_health_interval_ms_healthy != new_hot.me_health_interval_ms_healthy
@@ -1060,10 +1069,12 @@ fn log_changes(
|| old_hot.me_warn_rate_limit_ms != new_hot.me_warn_rate_limit_ms
{
info!(
"config reload: me_route_backpressure: base={}ms high={}ms watermark={}%; me_reader_route_data_wait_ms={}; me_health_interval: unhealthy={}ms healthy={}ms; me_admission_poll={}ms; me_warn_rate_limit={}ms",
"config reload: me_route_backpressure: enabled={} base={}ms high={}ms watermark={}%; me_route_fairshare_enabled={}; me_reader_route_data_wait_ms={}; me_health_interval: unhealthy={}ms healthy={}ms; me_admission_poll={}ms; me_warn_rate_limit={}ms",
new_hot.me_route_backpressure_enabled,
new_hot.me_route_backpressure_base_timeout_ms,
new_hot.me_route_backpressure_high_timeout_ms,
new_hot.me_route_backpressure_high_watermark_pct,
new_hot.me_route_fairshare_enabled,
new_hot.me_reader_route_data_wait_ms,
new_hot.me_health_interval_ms_unhealthy,
new_hot.me_health_interval_ms_healthy,

View File

@@ -640,12 +640,6 @@ impl ProxyConfig {
));
}
if config.censorship.mask_relay_max_bytes == 0 {
return Err(ProxyError::Config(
"censorship.mask_relay_max_bytes must be > 0".to_string(),
));
}
if config.censorship.mask_relay_max_bytes > 67_108_864 {
return Err(ProxyError::Config(
"censorship.mask_relay_max_bytes must be <= 67108864".to_string(),

View File

@@ -238,7 +238,7 @@ mask_shape_above_cap_blur_max_bytes = 8
}
#[test]
fn load_rejects_zero_mask_relay_max_bytes() {
fn load_accepts_zero_mask_relay_max_bytes_as_unlimited() {
let path = write_temp_config(
r#"
[censorship]
@@ -246,12 +246,9 @@ mask_relay_max_bytes = 0
"#,
);
let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_relay_max_bytes must be > 0"),
"error must explain non-zero relay cap invariant, got: {msg}"
);
let cfg = ProxyConfig::load(&path)
.expect("mask_relay_max_bytes=0 must be accepted as unlimited relay cap");
assert_eq!(cfg.censorship.mask_relay_max_bytes, 0);
remove_temp_config(&path);
}

View File

@@ -729,6 +729,14 @@ pub struct GeneralConfig {
#[serde(default)]
pub me_socks_kdf_policy: MeSocksKdfPolicy,
/// Enable route-level ME backpressure controls in reader fairness path.
#[serde(default = "default_me_route_backpressure_enabled")]
pub me_route_backpressure_enabled: bool,
/// Enable worker-local fairshare scheduler for ME reader routing.
#[serde(default = "default_me_route_fairshare_enabled")]
pub me_route_fairshare_enabled: bool,
/// Base backpressure timeout in milliseconds for ME route channel send.
#[serde(default = "default_me_route_backpressure_base_timeout_ms")]
pub me_route_backpressure_base_timeout_ms: u64,
@@ -1059,6 +1067,8 @@ impl Default for GeneralConfig {
disable_colors: false,
telemetry: TelemetryConfig::default(),
me_socks_kdf_policy: MeSocksKdfPolicy::Strict,
me_route_backpressure_enabled: default_me_route_backpressure_enabled(),
me_route_fairshare_enabled: default_me_route_fairshare_enabled(),
me_route_backpressure_base_timeout_ms: default_me_route_backpressure_base_timeout_ms(),
me_route_backpressure_high_timeout_ms: default_me_route_backpressure_high_timeout_ms(),
me_route_backpressure_high_watermark_pct:
@@ -1713,9 +1723,16 @@ pub struct AntiCensorshipConfig {
#[serde(default = "default_tls_new_session_tickets")]
pub tls_new_session_tickets: u8,
/// Enable compact ServerHello payload mode.
/// When false, FakeTLS always uses full ServerHello payload behavior.
/// When true, compact certificate payload mode can be used by TTL policy.
#[serde(default = "default_serverhello_compact")]
pub serverhello_compact: bool,
/// TTL in seconds for sending full certificate payload per client IP.
/// First client connection per (SNI domain, client IP) gets full cert payload.
/// Subsequent handshakes within TTL use compact cert metadata payload.
/// Applied only when `serverhello_compact` is enabled.
#[serde(default = "default_tls_full_cert_ttl_secs")]
pub tls_full_cert_ttl_secs: u64,
@@ -1758,6 +1775,7 @@ pub struct AntiCensorshipConfig {
pub mask_shape_above_cap_blur_max_bytes: usize,
/// Maximum bytes relayed per direction on unauthenticated masking fallback paths.
/// Set to 0 to disable byte cap (unlimited within relay/idle timeouts).
#[serde(default = "default_mask_relay_max_bytes")]
pub mask_relay_max_bytes: usize,
@@ -1809,6 +1827,7 @@ impl Default for AntiCensorshipConfig {
server_hello_delay_min_ms: default_server_hello_delay_min_ms(),
server_hello_delay_max_ms: default_server_hello_delay_max_ms(),
tls_new_session_tickets: default_tls_new_session_tickets(),
serverhello_compact: default_serverhello_compact(),
tls_full_cert_ttl_secs: default_tls_full_cert_ttl_secs(),
alpn_enforce: default_alpn_enforce(),
mask_proxy_protocol: 0,

View File

@@ -22,7 +22,7 @@ pub struct UserIpTracker {
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
limit_window: Arc<RwLock<Duration>>,
last_compact_epoch_secs: Arc<AtomicU64>,
cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
cleanup_queue: Arc<Mutex<HashMap<(String, IpAddr), usize>>>,
cleanup_drain_lock: Arc<AsyncMutex<()>>,
}
@@ -45,17 +45,21 @@ impl UserIpTracker {
limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
cleanup_queue: Arc::new(Mutex::new(Vec::new())),
cleanup_queue: Arc::new(Mutex::new(HashMap::new())),
cleanup_drain_lock: Arc::new(AsyncMutex::new(())),
}
}
pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
match self.cleanup_queue.lock() {
Ok(mut queue) => queue.push((user, ip)),
Ok(mut queue) => {
let count = queue.entry((user, ip)).or_insert(0);
*count = count.saturating_add(1);
}
Err(poisoned) => {
let mut queue = poisoned.into_inner();
queue.push((user.clone(), ip));
let count = queue.entry((user.clone(), ip)).or_insert(0);
*count = count.saturating_add(1);
self.cleanup_queue.clear_poison();
tracing::warn!(
"UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})",
@@ -75,7 +79,9 @@ impl UserIpTracker {
}
#[cfg(test)]
pub(crate) fn cleanup_queue_mutex_for_tests(&self) -> Arc<Mutex<Vec<(String, IpAddr)>>> {
pub(crate) fn cleanup_queue_mutex_for_tests(
&self,
) -> Arc<Mutex<HashMap<(String, IpAddr), usize>>> {
Arc::clone(&self.cleanup_queue)
}
@@ -105,11 +111,14 @@ impl UserIpTracker {
};
let mut active_ips = self.active_ips.write().await;
for (user, ip) in to_remove {
for ((user, ip), pending_count) in to_remove {
if pending_count == 0 {
continue;
}
if let Some(user_ips) = active_ips.get_mut(&user) {
if let Some(count) = user_ips.get_mut(&ip) {
if *count > 1 {
*count -= 1;
if *count > pending_count {
*count -= pending_count;
} else {
user_ips.remove(&ip);
}

View File

@@ -231,7 +231,11 @@ fn print_help() {
#[cfg(test)]
mod tests {
use super::resolve_runtime_config_path;
use super::{
expected_handshake_close_description, is_expected_handshake_eof, peer_close_description,
resolve_runtime_config_path,
};
use crate::error::{ProxyError, StreamError};
#[test]
fn resolve_runtime_config_path_anchors_relative_to_startup_cwd() {
@@ -299,6 +303,81 @@ mod tests {
let _ = std::fs::remove_dir(&startup_cwd);
}
#[test]
fn expected_handshake_eof_matches_connection_reset() {
let err = ProxyError::Io(std::io::Error::from(std::io::ErrorKind::ConnectionReset));
assert!(is_expected_handshake_eof(&err));
}
#[test]
fn expected_handshake_eof_matches_stream_io_unexpected_eof() {
let err = ProxyError::Stream(StreamError::Io(std::io::Error::from(
std::io::ErrorKind::UnexpectedEof,
)));
assert!(is_expected_handshake_eof(&err));
}
#[test]
fn peer_close_description_is_human_readable_for_all_peer_close_kinds() {
let cases = [
(
std::io::ErrorKind::ConnectionReset,
"Peer reset TCP connection (RST)",
),
(
std::io::ErrorKind::ConnectionAborted,
"Peer aborted TCP connection during transport",
),
(
std::io::ErrorKind::BrokenPipe,
"Peer closed write side (broken pipe)",
),
(
std::io::ErrorKind::NotConnected,
"Socket was already closed by peer",
),
];
for (kind, expected) in cases {
let err = ProxyError::Io(std::io::Error::from(kind));
assert_eq!(peer_close_description(&err), Some(expected));
}
}
#[test]
fn handshake_close_description_is_human_readable_for_all_expected_kinds() {
let cases = [
(
ProxyError::Io(std::io::Error::from(std::io::ErrorKind::UnexpectedEof)),
"Peer closed before sending full 64-byte MTProto handshake",
),
(
ProxyError::Io(std::io::Error::from(std::io::ErrorKind::ConnectionReset)),
"Peer reset TCP connection during initial MTProto handshake",
),
(
ProxyError::Io(std::io::Error::from(std::io::ErrorKind::ConnectionAborted)),
"Peer aborted TCP connection during initial MTProto handshake",
),
(
ProxyError::Io(std::io::Error::from(std::io::ErrorKind::BrokenPipe)),
"Peer closed write side before MTProto handshake completed",
),
(
ProxyError::Io(std::io::Error::from(std::io::ErrorKind::NotConnected)),
"Handshake socket was already closed by peer",
),
(
ProxyError::Stream(StreamError::UnexpectedEof),
"Peer closed before sending full 64-byte MTProto handshake",
),
];
for (err, expected) in cases {
assert_eq!(expected_handshake_close_description(&err), Some(expected));
}
}
}
pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
@@ -428,7 +507,63 @@ pub(crate) async fn wait_until_admission_open(admission_rx: &mut watch::Receiver
}
pub(crate) fn is_expected_handshake_eof(err: &crate::error::ProxyError) -> bool {
err.to_string().contains("expected 64 bytes, got 0")
expected_handshake_close_description(err).is_some()
}
pub(crate) fn peer_close_description(err: &crate::error::ProxyError) -> Option<&'static str> {
fn from_kind(kind: std::io::ErrorKind) -> Option<&'static str> {
match kind {
std::io::ErrorKind::ConnectionReset => Some("Peer reset TCP connection (RST)"),
std::io::ErrorKind::ConnectionAborted => {
Some("Peer aborted TCP connection during transport")
}
std::io::ErrorKind::BrokenPipe => Some("Peer closed write side (broken pipe)"),
std::io::ErrorKind::NotConnected => Some("Socket was already closed by peer"),
_ => None,
}
}
match err {
crate::error::ProxyError::Io(ioe) => from_kind(ioe.kind()),
crate::error::ProxyError::Stream(crate::error::StreamError::Io(ioe)) => {
from_kind(ioe.kind())
}
_ => None,
}
}
pub(crate) fn expected_handshake_close_description(
err: &crate::error::ProxyError,
) -> Option<&'static str> {
fn from_kind(kind: std::io::ErrorKind) -> Option<&'static str> {
match kind {
std::io::ErrorKind::UnexpectedEof => {
Some("Peer closed before sending full 64-byte MTProto handshake")
}
std::io::ErrorKind::ConnectionReset => {
Some("Peer reset TCP connection during initial MTProto handshake")
}
std::io::ErrorKind::ConnectionAborted => {
Some("Peer aborted TCP connection during initial MTProto handshake")
}
std::io::ErrorKind::BrokenPipe => {
Some("Peer closed write side before MTProto handshake completed")
}
std::io::ErrorKind::NotConnected => Some("Handshake socket was already closed by peer"),
_ => None,
}
}
match err {
crate::error::ProxyError::Io(ioe) => from_kind(ioe.kind()),
crate::error::ProxyError::Stream(crate::error::StreamError::UnexpectedEof) => {
Some("Peer closed before sending full 64-byte MTProto handshake")
}
crate::error::ProxyError::Stream(crate::error::StreamError::Io(ioe)) => {
from_kind(ioe.kind())
}
_ => None,
}
}
pub(crate) async fn load_startup_proxy_config_snapshot(

View File

@@ -24,7 +24,10 @@ use crate::transport::middle_proxy::MePool;
use crate::transport::socket::set_linger_zero;
use crate::transport::{ListenOptions, UpstreamManager, create_listener, find_listener_processes};
use super::helpers::{is_expected_handshake_eof, print_proxy_links};
use super::helpers::{
expected_handshake_close_description, is_expected_handshake_eof, peer_close_description,
print_proxy_links,
};
pub(crate) struct BoundListeners {
pub(crate) listeners: Vec<(TcpListener, bool)>,
@@ -485,29 +488,9 @@ pub(crate) fn spawn_tcp_accept_loops(
Ok(guard) => *guard,
Err(_) => None,
};
let peer_closed = matches!(
&e,
crate::error::ProxyError::Io(ioe)
if matches!(
ioe.kind(),
std::io::ErrorKind::ConnectionReset
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::BrokenPipe
| std::io::ErrorKind::NotConnected
)
) || matches!(
&e,
crate::error::ProxyError::Stream(
crate::error::StreamError::Io(ioe)
)
if matches!(
ioe.kind(),
std::io::ErrorKind::ConnectionReset
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::BrokenPipe
| std::io::ErrorKind::NotConnected
)
);
let peer_close_reason = peer_close_description(&e);
let handshake_close_reason =
expected_handshake_close_description(&e);
let me_closed = matches!(
&e,
@@ -518,12 +501,23 @@ pub(crate) fn spawn_tcp_accept_loops(
crate::error::ProxyError::Proxy(msg) if msg == ROUTE_SWITCH_ERROR_MSG
);
match (peer_closed, me_closed) {
(true, _) => {
match (peer_close_reason, me_closed) {
(Some(reason), _) => {
if let Some(real_peer) = real_peer {
debug!(peer = %peer_addr, real_peer = %real_peer, error = %e, "Connection closed by client");
debug!(
peer = %peer_addr,
real_peer = %real_peer,
error = %e,
close_reason = reason,
"Connection closed by peer"
);
} else {
debug!(peer = %peer_addr, error = %e, "Connection closed by client");
debug!(
peer = %peer_addr,
error = %e,
close_reason = reason,
"Connection closed by peer"
);
}
}
(_, true) => {
@@ -541,10 +535,23 @@ pub(crate) fn spawn_tcp_accept_loops(
}
}
_ if is_expected_handshake_eof(&e) => {
let reason = handshake_close_reason
.unwrap_or("Peer closed during initial handshake");
if let Some(real_peer) = real_peer {
info!(peer = %peer_addr, real_peer = %real_peer, error = %e, "Connection closed during initial handshake");
info!(
peer = %peer_addr,
real_peer = %real_peer,
error = %e,
close_reason = reason,
"Connection closed during initial handshake"
);
} else {
info!(peer = %peer_addr, error = %e, "Connection closed during initial handshake");
info!(
peer = %peer_addr,
error = %e,
close_reason = reason,
"Connection closed during initial handshake"
);
}
}
_ => {

View File

@@ -277,6 +277,8 @@ pub(crate) async fn initialize_me_pool(
config.general.me_socks_kdf_policy,
config.general.me_writer_cmd_channel_capacity,
config.general.me_route_channel_capacity,
config.general.me_route_backpressure_enabled,
config.general.me_route_fairshare_enabled,
config.general.me_route_backpressure_base_timeout_ms,
config.general.me_route_backpressure_high_timeout_ms,
config.general.me_route_backpressure_high_watermark_pct,

View File

@@ -122,6 +122,8 @@ pub(crate) async fn spawn_runtime_tasks(
if let Some(pool) = &me_pool_for_policy {
pool.update_runtime_transport_policy(
cfg.general.me_socks_kdf_policy,
cfg.general.me_route_backpressure_enabled,
cfg.general.me_route_fairshare_enabled,
cfg.general.me_route_backpressure_base_timeout_ms,
cfg.general.me_route_backpressure_high_timeout_ms,
cfg.general.me_route_backpressure_high_watermark_pct,

View File

@@ -1383,6 +1383,8 @@ fn emulated_server_hello_never_places_alpn_in_server_hello_extensions() {
&session_id,
&cached,
false,
true,
ClientHelloTlsVersion::Tls13,
&rng,
Some(b"h2".to_vec()),
0,
@@ -1624,6 +1626,34 @@ fn test_extract_alpn_multiple() {
assert_eq!(alpn_str, vec!["h2", "spdy", "h3"]);
}
#[test]
fn detect_client_hello_tls_version_prefers_supported_versions_tls13() {
let supported_versions = vec![4, 0x03, 0x04, 0x03, 0x03];
let ch = build_client_hello_with_exts(vec![(0x002b, supported_versions)], "example.com");
assert_eq!(
detect_client_hello_tls_version(&ch),
Some(ClientHelloTlsVersion::Tls13)
);
}
#[test]
fn detect_client_hello_tls_version_falls_back_to_legacy_tls12() {
let ch = build_client_hello_with_exts(Vec::new(), "example.com");
assert_eq!(
detect_client_hello_tls_version(&ch),
Some(ClientHelloTlsVersion::Tls12)
);
}
#[test]
fn detect_client_hello_tls_version_rejects_malformed_supported_versions() {
// list_len=3 is invalid because version vector must contain u16 pairs.
let malformed_supported_versions = vec![3, 0x03, 0x04, 0x03];
let ch =
build_client_hello_with_exts(vec![(0x002b, malformed_supported_versions)], "example.com");
assert!(detect_client_hello_tls_version(&ch).is_none());
}
#[test]
fn extract_sni_rejects_zero_length_host_name() {
let mut sni_ext = Vec::new();

View File

@@ -811,6 +811,122 @@ pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
out
}
/// ClientHello TLS generation inferred from handshake fields.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientHelloTlsVersion {
Tls12,
Tls13,
}
/// Detect TLS generation from a ClientHello.
///
/// The parser prefers `supported_versions` (0x002b) when present and falls back
/// to `legacy_version` for compatibility with TLS 1.2 style hellos.
pub fn detect_client_hello_tls_version(handshake: &[u8]) -> Option<ClientHelloTlsVersion> {
if handshake.len() < 5 || handshake[0] != TLS_RECORD_HANDSHAKE {
return None;
}
let record_len = u16::from_be_bytes([handshake[3], handshake[4]]) as usize;
if handshake.len() < 5 + record_len {
return None;
}
let mut pos = 5; // after record header
if handshake.get(pos) != Some(&0x01) {
return None; // not ClientHello
}
pos += 1; // message type
if pos + 3 > handshake.len() {
return None;
}
let handshake_len = ((handshake[pos] as usize) << 16)
| ((handshake[pos + 1] as usize) << 8)
| handshake[pos + 2] as usize;
pos += 3; // handshake length bytes
if pos + handshake_len > 5 + record_len {
return None;
}
if pos + 2 + 32 > handshake.len() {
return None;
}
let legacy_version = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]);
pos += 2 + 32; // version + random
let session_id_len = *handshake.get(pos)? as usize;
pos += 1 + session_id_len;
if pos + 2 > handshake.len() {
return None;
}
let cipher_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
pos += 2 + cipher_len;
if pos >= handshake.len() {
return None;
}
let comp_len = *handshake.get(pos)? as usize;
pos += 1 + comp_len;
if pos + 2 > handshake.len() {
return None;
}
let ext_len = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]) as usize;
pos += 2;
let ext_end = pos + ext_len;
if ext_end > handshake.len() {
return None;
}
while pos + 4 <= ext_end {
let etype = u16::from_be_bytes([handshake[pos], handshake[pos + 1]]);
let elen = u16::from_be_bytes([handshake[pos + 2], handshake[pos + 3]]) as usize;
pos += 4;
if pos + elen > ext_end {
return None;
}
if etype == extension_type::SUPPORTED_VERSIONS {
if elen < 1 {
return None;
}
let list_len = handshake[pos] as usize;
if list_len == 0 || list_len % 2 != 0 || 1 + list_len > elen {
return None;
}
let mut has_tls12 = false;
let mut ver_pos = pos + 1;
let ver_end = ver_pos + list_len;
while ver_pos + 1 < ver_end {
let version = u16::from_be_bytes([handshake[ver_pos], handshake[ver_pos + 1]]);
if version == 0x0304 {
return Some(ClientHelloTlsVersion::Tls13);
}
if version == 0x0303 || version == 0x0302 || version == 0x0301 {
has_tls12 = true;
}
ver_pos += 2;
}
if has_tls12 {
return Some(ClientHelloTlsVersion::Tls12);
}
return None;
}
pos += elen;
}
if legacy_version >= 0x0303 {
Some(ClientHelloTlsVersion::Tls12)
} else {
None
}
}
/// Check if bytes look like a TLS ClientHello
pub fn is_tls_handshake(first_bytes: &[u8]) -> bool {
if first_bytes.len() < 3 {

View File

@@ -31,16 +31,24 @@ struct UserConnectionReservation {
ip_tracker: Arc<UserIpTracker>,
user: String,
ip: IpAddr,
tracks_ip: bool,
active: bool,
}
impl UserConnectionReservation {
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
fn new(
stats: Arc<Stats>,
ip_tracker: Arc<UserIpTracker>,
user: String,
ip: IpAddr,
tracks_ip: bool,
) -> Self {
Self {
stats,
ip_tracker,
user,
ip,
tracks_ip,
active: true,
}
}
@@ -49,7 +57,9 @@ impl UserConnectionReservation {
if !self.active {
return;
}
self.ip_tracker.remove_ip(&self.user, self.ip).await;
if self.tracks_ip {
self.ip_tracker.remove_ip(&self.user, self.ip).await;
}
self.active = false;
self.stats.decrement_user_curr_connects(&self.user);
}
@@ -62,7 +72,9 @@ impl Drop for UserConnectionReservation {
}
self.active = false;
self.stats.decrement_user_curr_connects(&self.user);
self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
if self.tracks_ip {
self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
}
}
}
@@ -324,17 +336,38 @@ fn record_beobachten_class(
beobachten.record(class, peer_ip, beobachten_ttl(config));
}
fn classify_expected_64_got_0(kind: std::io::ErrorKind) -> Option<&'static str> {
match kind {
std::io::ErrorKind::UnexpectedEof => Some("expected_64_got_0_unexpected_eof"),
std::io::ErrorKind::ConnectionReset => Some("expected_64_got_0_connection_reset"),
std::io::ErrorKind::ConnectionAborted => Some("expected_64_got_0_connection_aborted"),
std::io::ErrorKind::BrokenPipe => Some("expected_64_got_0_broken_pipe"),
std::io::ErrorKind::NotConnected => Some("expected_64_got_0_not_connected"),
_ => None,
}
}
fn classify_handshake_failure_class(error: &ProxyError) -> &'static str {
match error {
ProxyError::Io(err) => classify_expected_64_got_0(err.kind()).unwrap_or("other"),
ProxyError::Stream(StreamError::UnexpectedEof) => "expected_64_got_0_unexpected_eof",
ProxyError::Stream(StreamError::Io(err)) => {
classify_expected_64_got_0(err.kind()).unwrap_or("other")
}
_ => "other",
}
}
fn record_handshake_failure_class(
beobachten: &BeobachtenStore,
config: &ProxyConfig,
peer_ip: IpAddr,
error: &ProxyError,
) {
let class = match error {
ProxyError::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
"expected_64_got_0"
}
ProxyError::Stream(StreamError::UnexpectedEof) => "expected_64_got_0",
// Keep beobachten buckets stable while detailed per-kind classification
// is tracked in API counters.
let class = match classify_handshake_failure_class(error) {
value if value.starts_with("expected_64_got_0_") => "expected_64_got_0",
_ => "other",
};
record_beobachten_class(beobachten, config, peer_ip, class);
@@ -343,7 +376,7 @@ fn record_handshake_failure_class(
#[inline]
fn increment_bad_on_unknown_tls_sni(stats: &Stats, error: &ProxyError) {
if matches!(error, ProxyError::UnknownTlsSni) {
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("unknown_tls_sni");
}
}
@@ -444,7 +477,7 @@ where
Ok(Ok(info)) => {
if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs)
{
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("proxy_protocol_untrusted");
warn!(
peer = %peer,
trusted = ?config.server.proxy_protocol_trusted_cidrs,
@@ -465,13 +498,13 @@ where
}
}
Ok(Err(e)) => {
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("proxy_protocol_invalid_header");
warn!(peer = %peer, error = %e, "Invalid PROXY protocol header");
record_beobachten_class(&beobachten, &config, peer.ip(), "other");
return Err(e);
}
Err(_) => {
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("proxy_protocol_header_timeout");
warn!(peer = %peer, timeout_ms = proxy_header_timeout.as_millis(), "PROXY protocol header timeout");
record_beobachten_class(&beobachten, &config, peer.ip(), "other");
return Err(ProxyError::InvalidProxyProtocol);
@@ -561,7 +594,7 @@ where
// third-party clients or future Telegram versions.
if !tls_clienthello_len_in_bounds(tls_len) {
debug!(peer = %real_peer, tls_len = tls_len, max_tls_len = MAX_TLS_PLAINTEXT_SIZE, "TLS handshake length out of bounds");
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("tls_clienthello_len_out_of_bounds");
maybe_apply_mask_reject_delay(&config).await;
let (reader, writer) = tokio::io::split(stream);
return Ok(masking_outcome(
@@ -581,7 +614,7 @@ where
Ok(n) => n,
Err(e) => {
debug!(peer = %real_peer, error = %e, tls_len = tls_len, "TLS ClientHello body read failed; engaging masking fallback");
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("tls_clienthello_read_error");
maybe_apply_mask_reject_delay(&config).await;
let initial_len = 5;
let (reader, writer) = tokio::io::split(stream);
@@ -599,7 +632,7 @@ where
if body_read < tls_len {
debug!(peer = %real_peer, got = body_read, expected = tls_len, "Truncated in-range TLS ClientHello; engaging masking fallback");
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("tls_clienthello_truncated");
maybe_apply_mask_reject_delay(&config).await;
let initial_len = 5 + body_read;
let (reader, writer) = tokio::io::split(stream);
@@ -623,7 +656,7 @@ where
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("tls_handshake_bad_client");
return Ok(masking_outcome(
reader,
writer,
@@ -663,7 +696,7 @@ where
wrap_tls_application_record(&pending_plaintext)
};
let reader = tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader);
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("tls_mtproto_bad_client");
debug!(
peer = %peer,
"Authenticated TLS session failed MTProto validation; engaging masking fallback"
@@ -693,7 +726,7 @@ where
} else {
if !config.general.modes.classic && !config.general.modes.secure {
debug!(peer = %real_peer, "Non-TLS modes disabled");
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("direct_modes_disabled");
maybe_apply_mask_reject_delay(&config).await;
let (reader, writer) = tokio::io::split(stream);
return Ok(masking_outcome(
@@ -720,7 +753,7 @@ where
).await {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("direct_mtproto_bad_client");
return Ok(masking_outcome(
reader,
writer,
@@ -757,6 +790,7 @@ where
Ok(Ok(outcome)) => outcome,
Ok(Err(e)) => {
debug!(peer = %peer, error = %e, "Handshake failed");
stats_for_timeout.increment_handshake_failure_class(classify_handshake_failure_class(&e));
record_handshake_failure_class(
&beobachten_for_timeout,
&config_for_timeout,
@@ -767,6 +801,7 @@ where
}
Err(_) => {
stats_for_timeout.increment_handshake_timeouts();
stats_for_timeout.increment_handshake_failure_class("timeout");
debug!(peer = %peer, "Handshake timeout");
record_beobachten_class(
&beobachten_for_timeout,
@@ -956,7 +991,8 @@ impl RunningClientHandler {
self.peer.ip(),
&self.config.server.proxy_protocol_trusted_cidrs,
) {
self.stats.increment_connects_bad();
self.stats
.increment_connects_bad_with_class("proxy_protocol_untrusted");
warn!(
peer = %self.peer,
trusted = ?self.config.server.proxy_protocol_trusted_cidrs,
@@ -986,7 +1022,8 @@ impl RunningClientHandler {
}
}
Ok(Err(e)) => {
self.stats.increment_connects_bad();
self.stats
.increment_connects_bad_with_class("proxy_protocol_invalid_header");
warn!(peer = %self.peer, error = %e, "Invalid PROXY protocol header");
record_beobachten_class(
&self.beobachten,
@@ -997,7 +1034,8 @@ impl RunningClientHandler {
return Err(e);
}
Err(_) => {
self.stats.increment_connects_bad();
self.stats
.increment_connects_bad_with_class("proxy_protocol_header_timeout");
warn!(
peer = %self.peer,
timeout_ms = proxy_header_timeout.as_millis(),
@@ -1095,6 +1133,7 @@ impl RunningClientHandler {
Ok(Ok(outcome)) => outcome,
Ok(Err(e)) => {
debug!(peer = %peer_for_log, error = %e, "Handshake failed");
stats.increment_handshake_failure_class(classify_handshake_failure_class(&e));
record_handshake_failure_class(
&beobachten_for_timeout,
&config_for_timeout,
@@ -1105,6 +1144,7 @@ impl RunningClientHandler {
}
Err(_) => {
stats.increment_handshake_timeouts();
stats.increment_handshake_failure_class("timeout");
debug!(peer = %peer_for_log, "Handshake timeout");
record_beobachten_class(
&beobachten_for_timeout,
@@ -1140,7 +1180,8 @@ impl RunningClientHandler {
// third-party clients or future Telegram versions.
if !tls_clienthello_len_in_bounds(tls_len) {
debug!(peer = %peer, tls_len = tls_len, max_tls_len = MAX_TLS_PLAINTEXT_SIZE, "TLS handshake length out of bounds");
self.stats.increment_connects_bad();
self.stats
.increment_connects_bad_with_class("tls_clienthello_len_out_of_bounds");
maybe_apply_mask_reject_delay(&self.config).await;
let (reader, writer) = self.stream.into_split();
return Ok(masking_outcome(
@@ -1160,7 +1201,8 @@ impl RunningClientHandler {
Ok(n) => n,
Err(e) => {
debug!(peer = %peer, error = %e, tls_len = tls_len, "TLS ClientHello body read failed; engaging masking fallback");
self.stats.increment_connects_bad();
self.stats
.increment_connects_bad_with_class("tls_clienthello_read_error");
maybe_apply_mask_reject_delay(&self.config).await;
let (reader, writer) = self.stream.into_split();
return Ok(masking_outcome(
@@ -1177,7 +1219,8 @@ impl RunningClientHandler {
if body_read < tls_len {
debug!(peer = %peer, got = body_read, expected = tls_len, "Truncated in-range TLS ClientHello; engaging masking fallback");
self.stats.increment_connects_bad();
self.stats
.increment_connects_bad_with_class("tls_clienthello_truncated");
maybe_apply_mask_reject_delay(&self.config).await;
let initial_len = 5 + body_read;
let (reader, writer) = self.stream.into_split();
@@ -1214,7 +1257,7 @@ impl RunningClientHandler {
{
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("tls_handshake_bad_client");
return Ok(masking_outcome(
reader,
writer,
@@ -1264,7 +1307,7 @@ impl RunningClientHandler {
};
let reader =
tokio::io::AsyncReadExt::chain(std::io::Cursor::new(pending_record), reader);
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("tls_mtproto_bad_client");
debug!(
peer = %peer,
"Authenticated TLS session failed MTProto validation; engaging masking fallback"
@@ -1311,7 +1354,8 @@ impl RunningClientHandler {
if !self.config.general.modes.classic && !self.config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled");
self.stats.increment_connects_bad();
self.stats
.increment_connects_bad_with_class("direct_modes_disabled");
maybe_apply_mask_reject_delay(&self.config).await;
let (reader, writer) = self.stream.into_split();
return Ok(masking_outcome(
@@ -1351,7 +1395,7 @@ impl RunningClientHandler {
{
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
stats.increment_connects_bad_with_class("direct_mtproto_bad_client");
return Ok(masking_outcome(
reader,
writer,
@@ -1568,19 +1612,22 @@ impl RunningClientHandler {
});
}
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {}
Err(reason) => {
stats.decrement_user_curr_connects(user);
warn!(
user = %user,
ip = %peer_addr.ip(),
reason = %reason,
"IP limit exceeded"
);
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
let tracks_ip = ip_tracker.get_user_limit(user).await.is_some();
if tracks_ip {
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {}
Err(reason) => {
stats.decrement_user_curr_connects(user);
warn!(
user = %user,
ip = %peer_addr.ip(),
reason = %reason,
"IP limit exceeded"
);
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
}
}
@@ -1589,6 +1636,7 @@ impl RunningClientHandler {
ip_tracker,
user.to_string(),
peer_addr.ip(),
tracks_ip,
))
}
@@ -1631,25 +1679,27 @@ impl RunningClientHandler {
});
}
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {
ip_tracker.remove_ip(user, peer_addr.ip()).await;
stats.decrement_user_curr_connects(user);
}
Err(reason) => {
stats.decrement_user_curr_connects(user);
warn!(
user = %user,
ip = %peer_addr.ip(),
reason = %reason,
"IP limit exceeded"
);
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
if ip_tracker.get_user_limit(user).await.is_some() {
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {
ip_tracker.remove_ip(user, peer_addr.ip()).await;
}
Err(reason) => {
stats.decrement_user_curr_connects(user);
warn!(
user = %user,
ip = %peer_addr.ip(),
reason = %reason,
"IP limit exceeded"
);
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
}
}
stats.decrement_user_curr_connects(user);
Ok(())
}
}

View File

@@ -55,6 +55,7 @@ const STICKY_HINT_MAX_ENTRIES: usize = 65_536;
const CANDIDATE_HINT_TRACK_CAP: usize = 64;
const OVERLOAD_CANDIDATE_BUDGET_HINTED: usize = 16;
const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8;
const EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD: usize = 64;
const RECENT_USER_RING_SCAN_LIMIT: usize = 32;
type HmacSha256 = Hmac<Sha256>;
@@ -551,6 +552,19 @@ fn auth_probe_note_saturation_in(shared: &ProxySharedState, now: Instant) {
}
}
fn auth_probe_note_expensive_invalid_scan_in(
shared: &ProxySharedState,
now: Instant,
validation_checks: usize,
overload: bool,
) {
if overload || validation_checks < EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD {
return;
}
auth_probe_note_saturation_in(shared, now);
}
fn auth_probe_record_failure_in(shared: &ProxySharedState, peer_ip: IpAddr, now: Instant) {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = &shared.handshake.auth_probe;
@@ -1119,6 +1133,10 @@ where
} else {
None
};
// Fail-closed to TLS 1.3 semantics when ClientHello version is ambiguous:
// this avoids leaking certificate payload on malformed probes.
let client_tls_version = tls::detect_client_hello_tls_version(handshake)
.unwrap_or(tls::ClientHelloTlsVersion::Tls13);
if client_sni.is_some() && matched_tls_domain.is_none() && preferred_user_hint.is_none() {
let sni = client_sni.as_deref().unwrap_or_default();
@@ -1374,7 +1392,14 @@ where
}
if !matched {
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
let failure_now = Instant::now();
auth_probe_note_expensive_invalid_scan_in(
shared,
failure_now,
validation_checks,
overload,
);
auth_probe_record_failure_in(shared, peer.ip(), failure_now);
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
@@ -1439,12 +1464,18 @@ where
let selected_domain =
matched_tls_domain.unwrap_or(config.censorship.tls_domain.as_str());
let cached_entry = cache.get(selected_domain).await;
let use_full_cert_payload = cache
.take_full_cert_budget_for_ip(
peer.ip(),
Duration::from_secs(config.censorship.tls_full_cert_ttl_secs),
)
.await;
let use_full_cert_payload = if config.censorship.serverhello_compact
&& matches!(client_tls_version, tls::ClientHelloTlsVersion::Tls12)
{
cache
.take_full_cert_budget_for_ip(
peer.ip(),
Duration::from_secs(config.censorship.tls_full_cert_ttl_secs),
)
.await
} else {
true
};
Some((cached_entry, use_full_cert_payload))
} else {
None
@@ -1465,6 +1496,8 @@ where
validation_session_id_slice,
&cached_entry,
use_full_cert_payload,
config.censorship.serverhello_compact,
client_tls_version,
rng,
selected_alpn.clone(),
config.censorship.tls_new_session_tickets,
@@ -1741,7 +1774,14 @@ where
}
if !matched {
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
let failure_now = Instant::now();
auth_probe_note_expensive_invalid_scan_in(
shared,
failure_now,
validation_checks,
overload,
);
auth_probe_record_failure_in(shared, peer.ip(), failure_now);
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,

View File

@@ -60,21 +60,18 @@ where
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
let mut total = 0usize;
let mut ended_by_eof = false;
if byte_cap == 0 {
return CopyOutcome {
total,
ended_by_eof,
};
}
let unlimited = byte_cap == 0;
loop {
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
let read_len = if unlimited {
MASK_BUFFER_SIZE
} else {
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
remaining_budget.min(MASK_BUFFER_SIZE)
};
let read_res = timeout(idle_timeout, reader.read(&mut buf[..read_len])).await;
let n = match read_res {
Ok(Ok(n)) => n,
@@ -930,21 +927,21 @@ async fn consume_client_data<R: AsyncRead + Unpin>(
byte_cap: usize,
idle_timeout: Duration,
) {
if byte_cap == 0 {
return;
}
// Keep drain path fail-closed under slow-loris stalls.
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
let mut total = 0usize;
let unlimited = byte_cap == 0;
loop {
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
let read_len = if unlimited {
MASK_BUFFER_SIZE
} else {
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
remaining_budget.min(MASK_BUFFER_SIZE)
};
let n = match timeout(idle_timeout, reader.read(&mut buf[..read_len])).await {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
@@ -955,7 +952,7 @@ async fn consume_client_data<R: AsyncRead + Unpin>(
}
total = total.saturating_add(n);
if total >= byte_cap {
if !unlimited && total >= byte_cap {
break;
}
}

View File

@@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot, watch};
use tokio::time::timeout;
use tracing::{debug, info, trace, warn};
@@ -36,7 +36,11 @@ use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
enum C2MeCommand {
Data { payload: PooledBuffer, flags: u32 },
Data {
payload: PooledBuffer,
flags: u32,
_permit: OwnedSemaphorePermit,
},
Close,
}
@@ -47,6 +51,8 @@ const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
const C2ME_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024;
const C2ME_QUEUED_PERMITS_PER_SLOT: usize = 4;
const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1);
const TINY_FRAME_DEBT_PER_TINY: u32 = 8;
const TINY_FRAME_DEBT_LIMIT: u32 = 512;
@@ -571,6 +577,43 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
}
fn c2me_payload_permits(payload_len: usize) -> u32 {
payload_len
.max(1)
.div_ceil(C2ME_QUEUED_BYTE_PERMIT_UNIT)
.min(u32::MAX as usize) as u32
}
fn c2me_queued_permit_budget(channel_capacity: usize, frame_limit: usize) -> usize {
channel_capacity
.saturating_mul(C2ME_QUEUED_PERMITS_PER_SLOT)
.max(c2me_payload_permits(frame_limit) as usize)
.max(1)
}
async fn acquire_c2me_payload_permit(
semaphore: &Arc<Semaphore>,
payload_len: usize,
send_timeout: Option<Duration>,
stats: &Stats,
) -> Result<OwnedSemaphorePermit> {
let permits = c2me_payload_permits(payload_len);
let acquire = semaphore.clone().acquire_many_owned(permits);
match send_timeout {
Some(send_timeout) => match timeout(send_timeout, acquire).await {
Ok(Ok(permit)) => Ok(permit),
Ok(Err(_)) => Err(ProxyError::Proxy("ME sender byte budget closed".into())),
Err(_) => {
stats.increment_me_c2me_send_timeout_total();
Err(ProxyError::Proxy("ME sender byte budget timeout".into()))
}
},
None => acquire
.await
.map_err(|_| ProxyError::Proxy("ME sender byte budget closed".into())),
}
}
fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
limit.saturating_add(overshoot)
}
@@ -1122,13 +1165,19 @@ where
0 => None,
timeout_ms => Some(Duration::from_millis(timeout_ms)),
};
let c2me_byte_budget = c2me_queued_permit_budget(c2me_channel_capacity, frame_limit);
let c2me_byte_semaphore = Arc::new(Semaphore::new(c2me_byte_budget));
let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity);
let me_pool_c2me = me_pool.clone();
let c2me_sender = tokio::spawn(async move {
let mut sent_since_yield = 0usize;
while let Some(cmd) = c2me_rx.recv().await {
match cmd {
C2MeCommand::Data { payload, flags } => {
C2MeCommand::Data {
payload,
flags,
_permit,
} => {
me_pool_c2me
.send_proxy_req(
conn_id,
@@ -1624,11 +1673,29 @@ where
if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) {
flags |= RPC_FLAG_NOT_ENCRYPTED;
}
let payload_permit = match acquire_c2me_payload_permit(
&c2me_byte_semaphore,
payload.len(),
c2me_send_timeout,
stats.as_ref(),
)
.await
{
Ok(permit) => permit,
Err(e) => {
main_result = Err(e);
break;
}
};
// Keep client read loop lightweight: route heavy ME send path via a dedicated task.
if enqueue_c2me_command_in(
shared.as_ref(),
&c2me_tx,
C2MeCommand::Data { payload, flags },
C2MeCommand::Data {
payload,
flags,
_permit: payload_permit,
},
c2me_send_timeout,
stats.as_ref(),
)
@@ -2201,6 +2268,7 @@ enum MeWriterResponseOutcome {
Close,
}
#[cfg(test)]
async fn process_me_writer_response<W>(
response: MeResponse,
client_writer: &mut CryptoWriter<W>,
@@ -2261,7 +2329,7 @@ where
W: AsyncWrite + Unpin + Send + 'static,
{
match response {
MeResponse::Data { flags, data } => {
MeResponse::Data { flags, data, .. } => {
if batched {
trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)");
} else {

View File

@@ -230,6 +230,7 @@ struct RateWaitState {
}
impl<S> StatsIo<S> {
#[cfg(test)]
fn new(
inner: S,
counters: Arc<SharedCounters>,

View File

@@ -282,7 +282,7 @@ async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
assert_eq!(stats.get_user_curr_connects(&user), 1);
let reservation =
UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip);
UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip, true);
// Drop the reservation synchronously without any tokio::spawn/await yielding!
drop(reservation);
@@ -320,6 +320,7 @@ async fn relay_task_abort_releases_user_gate_and_ip_reservation() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 8).await;
let mut cfg = ProxyConfig::default();
cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
@@ -437,6 +438,7 @@ async fn relay_cutover_releases_user_gate_and_ip_reservation() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 8).await;
let mut cfg = ProxyConfig::default();
cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
@@ -2493,6 +2495,46 @@ fn unexpected_eof_is_classified_without_string_matching() {
);
}
#[test]
fn connection_reset_is_classified_as_expected_handshake_close() {
let beobachten = BeobachtenStore::new();
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 1;
let reset = ProxyError::Io(std::io::Error::from(std::io::ErrorKind::ConnectionReset));
let peer_ip: IpAddr = "198.51.100.202".parse().unwrap();
record_handshake_failure_class(&beobachten, &config, peer_ip, &reset);
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(
snapshot.contains("[expected_64_got_0]"),
"ConnectionReset must be classified as expected handshake close"
);
}
#[test]
fn stream_io_unexpected_eof_is_classified_without_string_matching() {
let beobachten = BeobachtenStore::new();
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 1;
let eof = ProxyError::Stream(StreamError::Io(std::io::Error::from(
std::io::ErrorKind::UnexpectedEof,
)));
let peer_ip: IpAddr = "198.51.100.203".parse().unwrap();
record_handshake_failure_class(&beobachten, &config, peer_ip, &eof);
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(
snapshot.contains("[expected_64_got_0]"),
"StreamError::Io(UnexpectedEof) must be classified as expected handshake close"
);
}
#[test]
fn non_eof_error_is_classified_as_other() {
let beobachten = BeobachtenStore::new();
@@ -2839,6 +2881,7 @@ async fn explicit_reservation_release_cleans_user_and_ip_immediately() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 4).await;
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -2877,6 +2920,7 @@ async fn explicit_reservation_release_does_not_double_decrement_on_drop() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 4).await;
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -2907,6 +2951,7 @@ async fn drop_fallback_eventually_cleans_user_and_ip_reservation() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -2989,6 +3034,7 @@ async fn release_abort_storm_does_not_leak_user_or_ip_reservations() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, ATTEMPTS + 16).await;
for idx in 0..ATTEMPTS {
let peer = SocketAddr::new(
@@ -3039,6 +3085,7 @@ async fn release_abort_loop_preserves_immediate_same_ip_reacquire() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
for _ in 0..ITERATIONS {
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
@@ -3097,6 +3144,7 @@ async fn adversarial_mixed_release_drop_abort_wave_converges_to_zero() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, RESERVATIONS + 8).await;
let mut reservations = Vec::with_capacity(RESERVATIONS);
for idx in 0..RESERVATIONS {
@@ -3177,6 +3225,8 @@ async fn parallel_users_abort_release_isolation_preserves_independent_cleanup()
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user_a, 64).await;
ip_tracker.set_user_limit(user_b, 64).await;
let mut tasks = tokio::task::JoinSet::new();
for idx in 0..64usize {
@@ -3238,6 +3288,7 @@ async fn concurrent_release_storm_leaves_zero_user_and_ip_footprint() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, RESERVATIONS + 8).await;
let mut reservations = Vec::with_capacity(RESERVATIONS);
for idx in 0..RESERVATIONS {
@@ -3292,6 +3343,7 @@ async fn relay_connect_error_releases_user_and_ip_before_return() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 8).await;
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 1);
@@ -3387,6 +3439,7 @@ async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -3447,6 +3500,7 @@ async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -3656,6 +3710,7 @@ async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 8).await;
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -3700,6 +3755,7 @@ async fn immediate_reacquire_after_cross_thread_drop_succeeds() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,

View File

@@ -1252,6 +1252,97 @@ async fn tls_overload_budget_limits_candidate_scan_depth() {
);
}
#[tokio::test]
async fn tls_expensive_invalid_scan_activates_saturation_budget() {
let mut config = ProxyConfig::default();
config.access.users.clear();
config.access.ignore_time_skew = true;
for idx in 0..80u8 {
config.access.users.insert(
format!("user-{idx}"),
format!("{:032x}", u128::from(idx) + 1),
);
}
config.rebuild_runtime_user_auth().unwrap();
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let shared = ProxySharedState::new();
let attacker_secret = [0xEFu8; 16];
let handshake = make_valid_tls_handshake(&attacker_secret, 0);
let first_peer: SocketAddr = "198.51.100.214:44326".parse().unwrap();
let first = handle_tls_handshake_with_shared(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
first_peer,
&config,
&replay_checker,
&rng,
None,
shared.as_ref(),
)
.await;
assert!(matches!(first, HandshakeResult::BadClient { .. }));
assert!(
auth_probe_saturation_state_for_testing_in_shared(shared.as_ref())
.lock()
.unwrap()
.is_some(),
"expensive invalid scan must activate global saturation"
);
assert_eq!(
shared
.handshake
.auth_expensive_checks_total
.load(Ordering::Relaxed),
80,
"first invalid probe preserves full first-hit compatibility before enabling saturation"
);
{
let mut saturation = auth_probe_saturation_state_for_testing_in_shared(shared.as_ref())
.lock()
.unwrap();
let state = saturation.as_mut().expect("saturation must be present");
state.blocked_until = Instant::now() + Duration::from_millis(200);
}
let second_peer: SocketAddr = "198.51.100.215:44326".parse().unwrap();
let second = handle_tls_handshake_with_shared(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
second_peer,
&config,
&replay_checker,
&rng,
None,
shared.as_ref(),
)
.await;
assert!(matches!(second, HandshakeResult::BadClient { .. }));
assert_eq!(
shared
.handshake
.auth_budget_exhausted_total
.load(Ordering::Relaxed),
1,
"second invalid probe must be capped by overload budget"
);
assert_eq!(
shared
.handshake
.auth_expensive_checks_total
.load(Ordering::Relaxed),
80 + OVERLOAD_CANDIDATE_BUDGET_UNHINTED as u64,
"saturation budget must bound follow-up invalid scans"
);
}
#[tokio::test]
async fn mtproto_runtime_snapshot_prefers_preferred_user_hint() {
let mut config = ProxyConfig::default();

View File

@@ -58,11 +58,22 @@ async fn consume_stall_stress_finishes_within_idle_budget() {
}
#[tokio::test]
async fn consume_zero_cap_returns_immediately() {
async fn consume_zero_cap_is_idle_bounded_on_stall() {
let started = Instant::now();
consume_client_data(tokio::io::empty(), 0, MASK_RELAY_IDLE_TIMEOUT).await;
tokio::time::timeout(
MASK_RELAY_TIMEOUT,
consume_client_data(OneByteThenStall { sent: false }, 0, MASK_RELAY_IDLE_TIMEOUT),
)
.await
.expect("zero-cap consume path must remain bounded by timeout guards");
let elapsed = started.elapsed();
assert!(
started.elapsed() < MASK_RELAY_IDLE_TIMEOUT,
"zero byte cap must return immediately"
elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2),
"zero cap must not short-circuit before idle timeout path, got {elapsed:?}"
);
assert!(
elapsed < MASK_RELAY_TIMEOUT,
"zero-cap consume path must complete before relay timeout, got {elapsed:?}"
);
}

View File

@@ -148,9 +148,10 @@ async fn positive_copy_with_production_cap_stops_exactly_at_budget() {
}
#[tokio::test]
async fn negative_consume_with_zero_cap_performs_no_reads() {
let read_calls = Arc::new(AtomicUsize::new(0));
let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls));
async fn consume_with_zero_cap_drains_until_eof() {
let payload = 256 * 1024;
let total_read = Arc::new(AtomicUsize::new(0));
let reader = BudgetProbeReader::new(payload, Arc::clone(&total_read));
consume_client_data_with_timeout_and_cap(
reader,
@@ -161,9 +162,27 @@ async fn negative_consume_with_zero_cap_performs_no_reads() {
.await;
assert_eq!(
read_calls.load(Ordering::Relaxed),
0,
"zero cap must return before reading attacker-controlled bytes"
total_read.load(Ordering::Relaxed),
payload,
"zero cap must disable byte budget and drain finite payload to EOF"
);
}
#[tokio::test]
async fn copy_with_zero_cap_drains_until_eof() {
let read_calls = Arc::new(AtomicUsize::new(0));
let payload = 73 * 1024;
let mut reader = FinitePatternReader::new(payload, 3072, read_calls);
let mut writer = CountingWriter::default();
let outcome =
copy_with_idle_timeout(&mut reader, &mut writer, 0, true, MASK_RELAY_IDLE_TIMEOUT).await;
assert_eq!(outcome.total, payload);
assert_eq!(writer.written, payload);
assert!(
outcome.ended_by_eof,
"zero cap must not terminate relay early on byte budget"
);
}

View File

@@ -70,6 +70,7 @@ async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() {
MeResponse::Data {
flags: 0,
data: payload.clone(),
route_permit: None,
},
&mut writer,
ProtoTag::Intermediate,
@@ -139,6 +140,7 @@ async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() {
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]),
route_permit: None,
},
&mut writer,
ProtoTag::Intermediate,

View File

@@ -12,6 +12,12 @@ fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
payload
}
fn make_c2me_permit() -> tokio::sync::OwnedSemaphorePermit {
Arc::new(tokio::sync::Semaphore::new(1))
.try_acquire_many_owned(1)
.expect("test permit must be available")
}
#[test]
#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"]
fn should_emit_full_desync_filters_duplicates() {
@@ -107,6 +113,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
tx.send(C2MeCommand::Data {
payload: make_pooled_payload(&[0xAA]),
flags: 1,
_permit: make_c2me_permit(),
})
.await
.expect("priming queue with one frame must succeed");
@@ -119,6 +126,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
C2MeCommand::Data {
payload: make_pooled_payload(&[0xBB, 0xCC]),
flags: 2,
_permit: make_c2me_permit(),
},
None,
&stats,
@@ -138,7 +146,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
.expect("receiver should observe primed frame")
.expect("first queued command must exist");
match first {
C2MeCommand::Data { payload, flags } => {
C2MeCommand::Data { payload, flags, .. } => {
assert_eq!(payload.as_ref(), &[0xAA]);
assert_eq!(flags, 1);
}
@@ -155,7 +163,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
.expect("receiver should observe backpressure-resumed frame")
.expect("second queued command must exist");
match second {
C2MeCommand::Data { payload, flags } => {
C2MeCommand::Data { payload, flags, .. } => {
assert_eq!(payload.as_ref(), &[0xBB, 0xCC]);
assert_eq!(flags, 2);
}

View File

@@ -7,6 +7,7 @@ use std::time::{Duration, Instant};
use parking_lot::Mutex;
const CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
const MAX_BEOBACHTEN_ENTRIES: usize = 65_536;
#[derive(Default)]
struct BeobachtenInner {
@@ -48,12 +49,23 @@ impl BeobachtenStore {
Self::cleanup_if_needed(&mut guard, now, ttl);
let key = (class.to_string(), ip);
let entry = guard.entries.entry(key).or_insert(BeobachtenEntry {
tries: 0,
last_seen: now,
});
entry.tries = entry.tries.saturating_add(1);
entry.last_seen = now;
if let Some(entry) = guard.entries.get_mut(&key) {
entry.tries = entry.tries.saturating_add(1);
entry.last_seen = now;
return;
}
if guard.entries.len() >= MAX_BEOBACHTEN_ENTRIES {
return;
}
guard.entries.insert(
key,
BeobachtenEntry {
tries: 1,
last_seen: now,
},
);
}
pub fn snapshot_text(&self, ttl: Duration) -> String {

View File

@@ -88,6 +88,8 @@ impl Drop for RouteConnectionLease {
pub struct Stats {
connects_all: AtomicU64,
connects_bad: AtomicU64,
connects_bad_classes: DashMap<&'static str, AtomicU64>,
handshake_failure_classes: DashMap<&'static str, AtomicU64>,
current_connections_direct: AtomicU64,
current_connections_me: AtomicU64,
handshake_timeouts: AtomicU64,
@@ -518,10 +520,32 @@ impl Stats {
self.connects_all.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_connects_bad(&self) {
if self.telemetry_core_enabled() {
self.connects_bad.fetch_add(1, Ordering::Relaxed);
pub fn increment_connects_bad_with_class(&self, class: &'static str) {
if !self.telemetry_core_enabled() {
return;
}
self.connects_bad.fetch_add(1, Ordering::Relaxed);
let entry = self
.connects_bad_classes
.entry(class)
.or_insert_with(|| AtomicU64::new(0));
entry.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_connects_bad(&self) {
self.increment_connects_bad_with_class("other");
}
pub fn increment_handshake_failure_class(&self, class: &'static str) {
if !self.telemetry_core_enabled() {
return;
}
let entry = self
.handshake_failure_classes
.entry(class)
.or_insert_with(|| AtomicU64::new(0));
entry.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_current_connections_direct(&self) {
self.current_connections_direct
@@ -1640,6 +1664,37 @@ impl Stats {
pub fn get_connects_bad(&self) -> u64 {
self.connects_bad.load(Ordering::Relaxed)
}
pub fn get_connects_bad_class_counts(&self) -> Vec<(String, u64)> {
let mut out: Vec<(String, u64)> = self
.connects_bad_classes
.iter()
.map(|entry| {
(
entry.key().to_string(),
entry.value().load(Ordering::Relaxed),
)
})
.collect();
out.sort_by(|a, b| a.0.cmp(&b.0));
out
}
pub fn get_handshake_failure_class_counts(&self) -> Vec<(String, u64)> {
let mut out: Vec<(String, u64)> = self
.handshake_failure_classes
.iter()
.map(|entry| {
(
entry.key().to_string(),
entry.value().load(Ordering::Relaxed),
)
})
.collect();
out.sort_by(|a, b| a.0.cmp(&b.0));
out
}
pub fn get_accept_permit_timeout_total(&self) -> u64 {
self.accept_permit_timeout_total.load(Ordering::Relaxed)
}

View File

@@ -649,6 +649,25 @@ async fn duplicate_cleanup_entries_do_not_break_future_admission() {
);
}
#[tokio::test]
async fn duplicate_cleanup_entries_are_coalesced_until_drain() {
let tracker = UserIpTracker::new();
let ip = ip_from_idx(7150);
tracker.enqueue_cleanup("coalesced-cleanup".to_string(), ip);
tracker.enqueue_cleanup("coalesced-cleanup".to_string(), ip);
tracker.enqueue_cleanup("coalesced-cleanup".to_string(), ip);
assert_eq!(
tracker.cleanup_queue_len_for_tests(),
1,
"duplicate queued cleanup entries must retain one allocation slot"
);
tracker.drain_cleanup_queue().await;
assert_eq!(tracker.cleanup_queue_len_for_tests(), 0);
}
#[tokio::test]
async fn stress_repeated_queue_poison_recovery_preserves_admission_progress() {
let tracker = UserIpTracker::new();

View File

@@ -5,7 +5,9 @@ use crate::protocol::constants::{
MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
TLS_RECORD_HANDSHAKE, TLS_VERSION,
};
use crate::protocol::tls::{TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key};
use crate::protocol::tls::{
ClientHelloTlsVersion, TLS_DIGEST_LEN, TLS_DIGEST_POS, gen_fake_x25519_key,
};
use crate::tls_front::types::{CachedTlsData, ParsedCertificateInfo, TlsProfileSource};
use crc32fast::Hasher;
@@ -190,6 +192,8 @@ pub fn build_emulated_server_hello(
session_id: &[u8],
cached: &CachedTlsData,
use_full_cert_payload: bool,
serverhello_compact: bool,
client_tls_version: ClientHelloTlsVersion,
rng: &SecureRandom,
alpn: Option<Vec<u8>>,
new_session_tickets: u8,
@@ -265,20 +269,33 @@ pub fn build_emulated_server_hello(
}
}
};
let compact_payload = cached
.cert_info
.as_ref()
.and_then(build_compact_cert_info_payload)
.and_then(hash_compact_cert_info_payload);
let selected_payload: Option<&[u8]> = if use_full_cert_payload {
let compact_payload = if serverhello_compact {
cached
.cert_payload
.cert_info
.as_ref()
.map(|payload| payload.certificate_message.as_slice())
.filter(|payload| !payload.is_empty())
.or(compact_payload.as_deref())
.and_then(build_compact_cert_info_payload)
.and_then(hash_compact_cert_info_payload)
} else {
compact_payload.as_deref()
None
};
let full_payload = cached
.cert_payload
.as_ref()
.map(|payload| payload.certificate_message.as_slice())
.filter(|payload| !payload.is_empty());
let selected_payload: Option<&[u8]> = match client_tls_version {
ClientHelloTlsVersion::Tls13 => None,
ClientHelloTlsVersion::Tls12 => {
if serverhello_compact {
if use_full_cert_payload {
full_payload.or(compact_payload.as_deref())
} else {
compact_payload.as_deref()
}
} else {
full_payload
}
}
};
if let Some(payload) = selected_payload {
@@ -402,6 +419,7 @@ mod tests {
use crate::protocol::constants::{
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
};
use crate::protocol::tls::ClientHelloTlsVersion;
fn first_app_data_payload(response: &[u8]) -> &[u8] {
let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
@@ -448,6 +466,8 @@ mod tests {
&[0x22; 16],
&cached,
true,
true,
ClientHelloTlsVersion::Tls12,
&rng,
None,
0,
@@ -474,6 +494,8 @@ mod tests {
&[0x33; 16],
&cached,
true,
true,
ClientHelloTlsVersion::Tls12,
&rng,
None,
0,
@@ -506,6 +528,8 @@ mod tests {
&[0x55; 16],
&cached,
false,
true,
ClientHelloTlsVersion::Tls12,
&rng,
None,
0,
@@ -529,6 +553,68 @@ mod tests {
);
}
#[test]
fn test_build_emulated_server_hello_tls13_never_uses_cert_payload() {
let cert_msg = vec![0x0b, 0x00, 0x00, 0x05, 0x00, 0xaa, 0xbb, 0xcc, 0xdd];
let cached = make_cached(Some(TlsCertPayload {
cert_chain_der: vec![vec![0x30, 0x01, 0x00]],
certificate_message: cert_msg.clone(),
}));
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x56; 32],
&[0x78; 16],
&cached,
true,
true,
ClientHelloTlsVersion::Tls13,
&rng,
None,
0,
);
let payload = first_app_data_payload(&response);
assert!(
!payload.starts_with(&cert_msg),
"TLS 1.3 response path must not expose certificate payload bytes"
);
}
#[test]
fn test_build_emulated_server_hello_compact_disabled_skips_compact_payload() {
let mut cached = make_cached(None);
cached.cert_info = Some(crate::tls_front::types::ParsedCertificateInfo {
not_after_unix: Some(1_900_000_000),
not_before_unix: Some(1_700_000_000),
issuer_cn: Some("Issuer".to_string()),
subject_cn: Some("example.com".to_string()),
san_names: vec!["example.com".to_string()],
});
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x90; 32],
&[0x91; 16],
&cached,
false,
false,
ClientHelloTlsVersion::Tls12,
&rng,
Some(b"h2".to_vec()),
0,
);
let payload = first_app_data_payload(&response);
let expected_alpn_marker = [0x00u8, 0x10, 0x00, 0x05, 0x00, 0x03, 0x02, b'h', b'2'];
assert!(
payload.starts_with(&expected_alpn_marker),
"when compact mode is disabled and no full cert payload exists, the random/alpn path must be used"
);
}
#[test]
fn test_build_emulated_server_hello_ignores_tail_records_for_profiled_tls() {
let mut cached = make_cached(None);
@@ -545,6 +631,8 @@ mod tests {
&[0x34; 16],
&cached,
false,
true,
ClientHelloTlsVersion::Tls13,
&rng,
None,
0,

View File

@@ -20,6 +20,7 @@ use rustls::client::ClientConfig;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{DigitallySignedStruct, Error as RustlsError};
use x25519_dalek::{X25519_BASEPOINT_BYTES, x25519};
use x509_parser::certificate::X509Certificate;
use x509_parser::prelude::FromDer;
@@ -275,7 +276,7 @@ fn remember_profile_success(
);
}
fn build_client_config() -> Arc<ClientConfig> {
fn build_client_config(alpn_protocols: &[&[u8]]) -> Arc<ClientConfig> {
let root = rustls::RootCertStore::empty();
let provider = rustls::crypto::ring::default_provider();
@@ -288,6 +289,7 @@ fn build_client_config() -> Arc<ClientConfig> {
config
.dangerous()
.set_certificate_verifier(Arc::new(NoVerify));
config.alpn_protocols = alpn_protocols.iter().map(|proto| proto.to_vec()).collect();
Arc::new(config)
}
@@ -359,6 +361,22 @@ fn profile_alpn(profile: TlsFetchProfile) -> &'static [&'static [u8]] {
}
}
fn profile_alpn_labels(profile: TlsFetchProfile) -> &'static [&'static str] {
const H2_HTTP11: &[&str] = &["h2", "http/1.1"];
const HTTP11: &[&str] = &["http/1.1"];
match profile {
TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => H2_HTTP11,
TlsFetchProfile::CompatTls12 | TlsFetchProfile::LegacyMinimal => HTTP11,
}
}
fn profile_session_id_len(profile: TlsFetchProfile) -> usize {
match profile {
TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => 32,
TlsFetchProfile::CompatTls12 | TlsFetchProfile::LegacyMinimal => 0,
}
}
fn profile_supported_versions(profile: TlsFetchProfile) -> &'static [u16] {
const MODERN: &[u16] = &[0x0304, 0x0303];
const COMPAT: &[u16] = &[0x0303, 0x0304];
@@ -413,8 +431,20 @@ fn build_client_hello(
body.extend_from_slice(&rng.bytes(32));
}
// Session ID: empty
body.push(0);
// Use non-empty Session ID for modern TLS 1.3-like profiles to reduce middlebox friction.
let session_id_len = profile_session_id_len(profile);
let session_id = if session_id_len == 0 {
Vec::new()
} else if deterministic {
deterministic_bytes(
&format!("tls-fetch-session:{sni}:{}", profile.as_str()),
session_id_len,
)
} else {
rng.bytes(session_id_len)
};
body.push(session_id.len() as u8);
body.extend_from_slice(&session_id);
let mut cipher_suites = profile_cipher_suites(profile).to_vec();
if grease_enabled {
@@ -433,16 +463,26 @@ fn build_client_hello(
// === Extensions ===
let mut exts = Vec::new();
let mut push_extension = |ext_type: u16, data: &[u8]| {
exts.extend_from_slice(&ext_type.to_be_bytes());
exts.extend_from_slice(&(data.len() as u16).to_be_bytes());
exts.extend_from_slice(data);
};
// server_name (SNI)
let sni_bytes = sni.as_bytes();
let mut sni_ext = Vec::with_capacity(5 + sni_bytes.len());
sni_ext.extend_from_slice(&(sni_bytes.len() as u16 + 3).to_be_bytes());
sni_ext.push(0); // host_name
sni_ext.push(0);
sni_ext.extend_from_slice(&(sni_bytes.len() as u16).to_be_bytes());
sni_ext.extend_from_slice(sni_bytes);
exts.extend_from_slice(&0x0000u16.to_be_bytes());
exts.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes());
exts.extend_from_slice(&sni_ext);
push_extension(0x0000, &sni_ext);
// Chrome-like profile keeps browser-like ordering and extension set.
if matches!(profile, TlsFetchProfile::ModernChromeLike) {
// ec_point_formats: uncompressed only.
push_extension(0x000b, &[0x01, 0x00]);
}
// supported_groups
let mut groups = profile_groups(profile).to_vec();
@@ -450,11 +490,16 @@ fn build_client_hello(
let grease = grease_value(rng, deterministic, &format!("group:{sni}"));
groups.insert(0, grease);
}
exts.extend_from_slice(&0x000au16.to_be_bytes());
exts.extend_from_slice(&((2 + groups.len() * 2) as u16).to_be_bytes());
exts.extend_from_slice(&(groups.len() as u16 * 2).to_be_bytes());
let mut groups_ext = Vec::with_capacity(2 + groups.len() * 2);
groups_ext.extend_from_slice(&(groups.len() as u16 * 2).to_be_bytes());
for g in groups {
exts.extend_from_slice(&g.to_be_bytes());
groups_ext.extend_from_slice(&g.to_be_bytes());
}
push_extension(0x000a, &groups_ext);
if matches!(profile, TlsFetchProfile::ModernChromeLike) {
// session_ticket
push_extension(0x0023, &[]);
}
// signature_algorithms
@@ -463,12 +508,12 @@ fn build_client_hello(
let grease = grease_value(rng, deterministic, &format!("sigalg:{sni}"));
sig_algs.insert(0, grease);
}
exts.extend_from_slice(&0x000du16.to_be_bytes());
exts.extend_from_slice(&((2 + sig_algs.len() * 2) as u16).to_be_bytes());
exts.extend_from_slice(&(sig_algs.len() as u16 * 2).to_be_bytes());
let mut sig_algs_ext = Vec::with_capacity(2 + sig_algs.len() * 2);
sig_algs_ext.extend_from_slice(&(sig_algs.len() as u16 * 2).to_be_bytes());
for a in sig_algs {
exts.extend_from_slice(&a.to_be_bytes());
sig_algs_ext.extend_from_slice(&a.to_be_bytes());
}
push_extension(0x000d, &sig_algs_ext);
// supported_versions
let mut versions = profile_supported_versions(profile).to_vec();
@@ -476,30 +521,32 @@ fn build_client_hello(
let grease = grease_value(rng, deterministic, &format!("version:{sni}"));
versions.insert(0, grease);
}
exts.extend_from_slice(&0x002bu16.to_be_bytes());
exts.extend_from_slice(&((1 + versions.len() * 2) as u16).to_be_bytes());
exts.push((versions.len() * 2) as u8);
let mut versions_ext = Vec::with_capacity(1 + versions.len() * 2);
versions_ext.push((versions.len() * 2) as u8);
for v in versions {
exts.extend_from_slice(&v.to_be_bytes());
versions_ext.extend_from_slice(&v.to_be_bytes());
}
push_extension(0x002b, &versions_ext);
if matches!(profile, TlsFetchProfile::ModernChromeLike) {
// psk_key_exchange_modes: psk_dhe_ke
push_extension(0x002d, &[0x01, 0x01]);
}
// key_share (x25519)
let key = if deterministic {
let det = deterministic_bytes(&format!("keyshare:{sni}"), 32);
let mut key = [0u8; 32];
key.copy_from_slice(&det);
key
} else {
gen_key_share(rng)
};
let key = gen_key_share(
rng,
deterministic,
&format!("tls-fetch-keyshare:{sni}:{}", profile.as_str()),
);
let mut keyshare = Vec::with_capacity(4 + key.len());
keyshare.extend_from_slice(&0x001du16.to_be_bytes()); // group
keyshare.extend_from_slice(&0x001du16.to_be_bytes());
keyshare.extend_from_slice(&(key.len() as u16).to_be_bytes());
keyshare.extend_from_slice(&key);
exts.extend_from_slice(&0x0033u16.to_be_bytes());
exts.extend_from_slice(&((2 + keyshare.len()) as u16).to_be_bytes());
exts.extend_from_slice(&(keyshare.len() as u16).to_be_bytes());
exts.extend_from_slice(&keyshare);
let mut keyshare_ext = Vec::with_capacity(2 + keyshare.len());
keyshare_ext.extend_from_slice(&(keyshare.len() as u16).to_be_bytes());
keyshare_ext.extend_from_slice(&keyshare);
push_extension(0x0033, &keyshare_ext);
// ALPN
let mut alpn_list = Vec::new();
@@ -508,16 +555,15 @@ fn build_client_hello(
alpn_list.extend_from_slice(proto);
}
if !alpn_list.is_empty() {
exts.extend_from_slice(&0x0010u16.to_be_bytes());
exts.extend_from_slice(&((2 + alpn_list.len()) as u16).to_be_bytes());
exts.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
exts.extend_from_slice(&alpn_list);
let mut alpn_ext = Vec::with_capacity(2 + alpn_list.len());
alpn_ext.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_ext.extend_from_slice(&alpn_list);
push_extension(0x0010, &alpn_ext);
}
if grease_enabled {
let grease = grease_value(rng, deterministic, &format!("ext:{sni}"));
exts.extend_from_slice(&grease.to_be_bytes());
exts.extend_from_slice(&0u16.to_be_bytes());
push_extension(grease, &[]);
}
// padding to reduce recognizability and keep length ~500 bytes
@@ -553,10 +599,14 @@ fn build_client_hello(
record
}
fn gen_key_share(rng: &SecureRandom) -> [u8; 32] {
let mut key = [0u8; 32];
key.copy_from_slice(&rng.bytes(32));
key
fn gen_key_share(rng: &SecureRandom, deterministic: bool, seed: &str) -> [u8; 32] {
let mut scalar = [0u8; 32];
if deterministic {
scalar.copy_from_slice(&deterministic_bytes(seed, 32));
} else {
scalar.copy_from_slice(&rng.bytes(32));
}
x25519(scalar, X25519_BASEPOINT_BYTES)
}
async fn read_tls_record<S>(stream: &mut S) -> Result<(u8, Vec<u8>)>
@@ -1018,6 +1068,7 @@ async fn fetch_via_rustls_stream<S>(
host: &str,
sni: &str,
proxy_header: Option<Vec<u8>>,
alpn_protocols: &[&[u8]],
) -> Result<TlsFetchResult>
where
S: AsyncRead + AsyncWrite + Unpin,
@@ -1028,7 +1079,7 @@ where
stream.flush().await?;
}
let config = build_client_config();
let config = build_client_config(alpn_protocols);
let connector = TlsConnector::from(config);
let server_name = ServerName::try_from(sni.to_owned())
@@ -1113,6 +1164,7 @@ async fn fetch_via_rustls(
proxy_protocol: u8,
unix_sock: Option<&str>,
strict_route: bool,
alpn_protocols: &[&[u8]],
) -> Result<TlsFetchResult> {
#[cfg(unix)]
if let Some(sock_path) = unix_sock {
@@ -1124,7 +1176,8 @@ async fn fetch_via_rustls(
"Rustls fetch using mask unix socket"
);
let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, None, None);
return fetch_via_rustls_stream(stream, host, sni, proxy_header).await;
return fetch_via_rustls_stream(stream, host, sni, proxy_header, alpn_protocols)
.await;
}
Ok(Err(e)) => {
warn!(
@@ -1152,7 +1205,7 @@ async fn fetch_via_rustls(
.await?;
let (src_addr, dst_addr) = socket_addrs_from_upstream_stream(&stream);
let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, src_addr, dst_addr);
fetch_via_rustls_stream(stream, host, sni, proxy_header).await
fetch_via_rustls_stream(stream, host, sni, proxy_header, alpn_protocols).await
}
/// Fetch real TLS metadata with an adaptive multi-profile strategy.
@@ -1191,6 +1244,14 @@ pub async fn fetch_real_tls_with_strategy(
break;
}
let timeout_for_attempt = attempt_timeout.min(total_budget - elapsed);
debug!(
sni = %sni,
profile = profile.as_str(),
alpn = ?profile_alpn_labels(profile),
grease_enabled = strategy.grease_enabled,
deterministic = strategy.deterministic,
"TLS fetch ClientHello params (raw)"
);
match fetch_via_raw_tls(
host,
@@ -1256,6 +1317,16 @@ pub async fn fetch_real_tls_with_strategy(
}
let rustls_timeout = attempt_timeout.min(total_budget - elapsed);
let rustls_profile = selected_profile.unwrap_or(TlsFetchProfile::ModernChromeLike);
let rustls_alpn_protocols = profile_alpn(rustls_profile);
debug!(
sni = %sni,
profile = rustls_profile.as_str(),
alpn = ?profile_alpn_labels(rustls_profile),
grease_enabled = strategy.grease_enabled,
deterministic = strategy.deterministic,
"TLS fetch ClientHello params (rustls)"
);
let rustls_result = fetch_via_rustls(
host,
port,
@@ -1266,6 +1337,7 @@ pub async fn fetch_real_tls_with_strategy(
proxy_protocol,
unix_sock,
strategy.strict_route,
rustls_alpn_protocols,
)
.await;
@@ -1327,8 +1399,8 @@ mod tests {
use super::{
ProfileCacheValue, TlsFetchStrategy, build_client_hello, build_tls_fetch_proxy_header,
derive_behavior_profile, encode_tls13_certificate_message, order_profiles, profile_cache,
profile_cache_key,
derive_behavior_profile, encode_tls13_certificate_message, fetch_via_rustls_stream,
order_profiles, profile_alpn, profile_cache, profile_cache_key,
};
use crate::config::TlsFetchProfile;
use crate::crypto::SecureRandom;
@@ -1336,11 +1408,115 @@ mod tests {
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
};
use crate::tls_front::types::TlsProfileSource;
use tokio::io::AsyncReadExt;
struct ParsedClientHelloForTest {
session_id: Vec<u8>,
extensions: Vec<(u16, Vec<u8>)>,
}
fn read_u24(bytes: &[u8]) -> usize {
((bytes[0] as usize) << 16) | ((bytes[1] as usize) << 8) | (bytes[2] as usize)
}
fn parse_client_hello_for_test(record: &[u8]) -> ParsedClientHelloForTest {
assert!(record.len() >= 9, "record too short");
assert_eq!(record[0], TLS_RECORD_HANDSHAKE, "not a handshake record");
let record_len = u16::from_be_bytes([record[3], record[4]]) as usize;
assert_eq!(record.len(), 5 + record_len, "record length mismatch");
let handshake = &record[5..];
assert_eq!(handshake[0], 0x01, "not a ClientHello handshake");
let hello_len = read_u24(&handshake[1..4]);
assert_eq!(handshake.len(), 4 + hello_len, "handshake length mismatch");
let hello = &handshake[4..];
let mut pos = 0usize;
pos += 2;
pos += 32;
let session_len = hello[pos] as usize;
pos += 1;
let session_id = hello[pos..pos + session_len].to_vec();
pos += session_len;
let cipher_len = u16::from_be_bytes([hello[pos], hello[pos + 1]]) as usize;
pos += 2 + cipher_len;
let compression_len = hello[pos] as usize;
pos += 1 + compression_len;
let ext_len = u16::from_be_bytes([hello[pos], hello[pos + 1]]) as usize;
pos += 2;
let ext_end = pos + ext_len;
assert_eq!(ext_end, hello.len(), "extensions length mismatch");
let mut extensions = Vec::new();
while pos + 4 <= ext_end {
let ext_type = u16::from_be_bytes([hello[pos], hello[pos + 1]]);
let data_len = u16::from_be_bytes([hello[pos + 2], hello[pos + 3]]) as usize;
pos += 4;
let data = hello[pos..pos + data_len].to_vec();
pos += data_len;
extensions.push((ext_type, data));
}
assert_eq!(pos, ext_end, "extension parse did not consume all bytes");
ParsedClientHelloForTest {
session_id,
extensions,
}
}
fn parse_alpn_protocols(data: &[u8]) -> Vec<Vec<u8>> {
assert!(data.len() >= 2, "ALPN extension is too short");
let protocols_len = u16::from_be_bytes([data[0], data[1]]) as usize;
assert_eq!(protocols_len + 2, data.len(), "ALPN list length mismatch");
let mut pos = 2usize;
let mut out = Vec::new();
while pos < data.len() {
let len = data[pos] as usize;
pos += 1;
out.push(data[pos..pos + len].to_vec());
pos += len;
}
out
}
async fn capture_rustls_client_hello_record(
alpn_protocols: &'static [&'static [u8]],
) -> Vec<u8> {
let (client, mut server) = tokio::io::duplex(32 * 1024);
let fetch_task = tokio::spawn(async move {
fetch_via_rustls_stream(client, "example.com", "example.com", None, alpn_protocols)
.await
});
let mut header = [0u8; 5];
server
.read_exact(&mut header)
.await
.expect("must read client hello record header");
let body_len = u16::from_be_bytes([header[3], header[4]]) as usize;
let mut body = vec![0u8; body_len];
server
.read_exact(&mut body)
.await
.expect("must read client hello record body");
drop(server);
let result = fetch_task.await.expect("fetch task must join");
assert!(
result.is_err(),
"capture task should end with handshake error"
);
let mut record = Vec::with_capacity(5 + body_len);
record.extend_from_slice(&header);
record.extend_from_slice(&body);
record
}
#[test]
fn test_encode_tls13_certificate_message_single_cert() {
let cert = vec![0x30, 0x03, 0x02, 0x01, 0x01];
@@ -1470,6 +1646,186 @@ mod tests {
assert_eq!(first, second);
}
#[test]
fn test_raw_client_hello_alpn_matches_profile() {
let rng = SecureRandom::new();
for profile in [
TlsFetchProfile::ModernChromeLike,
TlsFetchProfile::ModernFirefoxLike,
TlsFetchProfile::CompatTls12,
TlsFetchProfile::LegacyMinimal,
] {
let hello = build_client_hello("alpn.example", &rng, profile, false, true);
let parsed = parse_client_hello_for_test(&hello);
let alpn_ext = parsed
.extensions
.iter()
.find(|(ext_type, _)| *ext_type == 0x0010)
.expect("ALPN extension must exist");
let parsed_alpn = parse_alpn_protocols(&alpn_ext.1);
let expected_alpn = profile_alpn(profile)
.iter()
.map(|proto| proto.to_vec())
.collect::<Vec<_>>();
assert_eq!(
parsed_alpn,
expected_alpn,
"ALPN mismatch for {}",
profile.as_str()
);
}
}
#[test]
fn test_modern_chrome_like_browser_extension_layout() {
let rng = SecureRandom::new();
let hello = build_client_hello(
"chrome.example",
&rng,
TlsFetchProfile::ModernChromeLike,
false,
true,
);
let parsed = parse_client_hello_for_test(&hello);
assert_eq!(
parsed.session_id.len(),
32,
"modern chrome must use non-empty session id"
);
let extension_ids = parsed
.extensions
.iter()
.map(|(ext_type, _)| *ext_type)
.collect::<Vec<_>>();
let expected_prefix = [
0x0000, 0x000b, 0x000a, 0x0023, 0x000d, 0x002b, 0x002d, 0x0033, 0x0010,
];
assert!(
extension_ids.as_slice().starts_with(&expected_prefix),
"unexpected extension order: {extension_ids:?}"
);
assert!(
extension_ids.contains(&0x0015),
"modern chrome profile should include padding extension"
);
let key_share = parsed
.extensions
.iter()
.find(|(ext_type, _)| *ext_type == 0x0033)
.expect("key_share extension must exist");
let key_share_data = &key_share.1;
assert!(
key_share_data.len() >= 2 + 4 + 32,
"key_share payload is too short"
);
let entry_len = u16::from_be_bytes([key_share_data[0], key_share_data[1]]) as usize;
assert_eq!(
entry_len,
key_share_data.len() - 2,
"key_share list length mismatch"
);
let group = u16::from_be_bytes([key_share_data[2], key_share_data[3]]);
let key_len = u16::from_be_bytes([key_share_data[4], key_share_data[5]]) as usize;
let key = &key_share_data[6..6 + key_len];
assert_eq!(group, 0x001d, "key_share group must be x25519");
assert_eq!(key_len, 32, "x25519 key length must be 32");
assert!(
key.iter().any(|b| *b != 0),
"x25519 key must not be all zero"
);
}
#[test]
fn test_fallback_profiles_keep_compat_extension_set() {
let rng = SecureRandom::new();
for profile in [
TlsFetchProfile::ModernFirefoxLike,
TlsFetchProfile::CompatTls12,
TlsFetchProfile::LegacyMinimal,
] {
let hello = build_client_hello("fallback.example", &rng, profile, false, true);
let parsed = parse_client_hello_for_test(&hello);
let extension_ids = parsed
.extensions
.iter()
.map(|(ext_type, _)| *ext_type)
.collect::<Vec<_>>();
assert!(extension_ids.contains(&0x0000), "SNI extension must exist");
assert!(
extension_ids.contains(&0x000a),
"supported_groups extension must exist"
);
assert!(
extension_ids.contains(&0x000d),
"signature_algorithms extension must exist"
);
assert!(
extension_ids.contains(&0x002b),
"supported_versions extension must exist"
);
assert!(
extension_ids.contains(&0x0033),
"key_share extension must exist"
);
assert!(extension_ids.contains(&0x0010), "ALPN extension must exist");
assert!(
!extension_ids.contains(&0x000b),
"ec_point_formats must stay chrome-only"
);
assert!(
!extension_ids.contains(&0x0023),
"session_ticket must stay chrome-only"
);
assert!(
!extension_ids.contains(&0x002d),
"psk_key_exchange_modes must stay chrome-only"
);
let expected_session_len = if matches!(profile, TlsFetchProfile::ModernFirefoxLike) {
32
} else {
0
};
assert_eq!(
parsed.session_id.len(),
expected_session_len,
"unexpected session id length for {}",
profile.as_str()
);
}
}
#[tokio::test(flavor = "current_thread")]
async fn test_rustls_client_hello_alpn_matches_selected_profile() {
for profile in [
TlsFetchProfile::ModernChromeLike,
TlsFetchProfile::CompatTls12,
TlsFetchProfile::LegacyMinimal,
] {
let record = capture_rustls_client_hello_record(profile_alpn(profile)).await;
let parsed = parse_client_hello_for_test(&record);
let alpn_ext = parsed
.extensions
.iter()
.find(|(ext_type, _)| *ext_type == 0x0010)
.expect("ALPN extension must exist");
let parsed_alpn = parse_alpn_protocols(&alpn_ext.1);
let expected_alpn = profile_alpn(profile)
.iter()
.map(|proto| proto.to_vec())
.collect::<Vec<_>>();
assert_eq!(
parsed_alpn,
expected_alpn,
"rustls ALPN mismatch for {}",
profile.as_str()
);
}
}
#[test]
fn test_build_tls_fetch_proxy_header_v2_with_tcp_addrs() {
let src: SocketAddr = "198.51.100.10:42000".parse().expect("valid src");

View File

@@ -4,6 +4,7 @@ use crate::crypto::SecureRandom;
use crate::protocol::constants::{
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
};
use crate::protocol::tls::ClientHelloTlsVersion;
use crate::tls_front::emulator::build_emulated_server_hello;
use crate::tls_front::types::{
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource,
@@ -62,6 +63,8 @@ fn emulated_server_hello_keeps_single_change_cipher_spec_for_client_compatibilit
&[0x72; 16],
&cached,
false,
true,
ClientHelloTlsVersion::Tls13,
&rng,
None,
0,
@@ -84,6 +87,8 @@ fn emulated_server_hello_does_not_emit_profile_ticket_tail_when_disabled() {
&[0x82; 16],
&cached,
false,
true,
ClientHelloTlsVersion::Tls13,
&rng,
None,
0,
@@ -104,6 +109,8 @@ fn emulated_server_hello_uses_profile_ticket_lengths_when_enabled() {
&[0x92; 16],
&cached,
false,
true,
ClientHelloTlsVersion::Tls13,
&rng,
None,
2,

View File

@@ -4,6 +4,7 @@ use crate::crypto::SecureRandom;
use crate::protocol::constants::{
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
};
use crate::protocol::tls::ClientHelloTlsVersion;
use crate::tls_front::emulator::build_emulated_server_hello;
use crate::tls_front::types::{
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsProfileSource,
@@ -55,6 +56,8 @@ fn emulated_server_hello_ignores_oversized_alpn_when_marker_would_not_fit() {
&[0x22; 16],
&cached,
true,
true,
ClientHelloTlsVersion::Tls13,
&rng,
Some(oversized_alpn),
0,
@@ -91,6 +94,8 @@ fn emulated_server_hello_embeds_full_alpn_marker_when_body_can_fit() {
&[0x41; 16],
&cached,
true,
true,
ClientHelloTlsVersion::Tls13,
&rng,
Some(b"h2".to_vec()),
0,
@@ -119,6 +124,8 @@ fn emulated_server_hello_prefers_cert_payload_over_alpn_marker() {
&[0x42; 16],
&cached,
true,
true,
ClientHelloTlsVersion::Tls12,
&rng,
Some(b"h2".to_vec()),
0,

View File

@@ -12,6 +12,7 @@ pub(crate) struct PressureSignals {
#[derive(Debug, Clone)]
pub(crate) struct PressureConfig {
pub(crate) backpressure_enabled: bool,
pub(crate) evaluate_every_rounds: u32,
pub(crate) transition_hysteresis_rounds: u8,
pub(crate) standing_ratio_pressured_pct: u8,
@@ -32,6 +33,7 @@ pub(crate) struct PressureConfig {
impl Default for PressureConfig {
fn default() -> Self {
Self {
backpressure_enabled: true,
evaluate_every_rounds: 8,
transition_hysteresis_rounds: 3,
standing_ratio_pressured_pct: 20,
@@ -99,6 +101,13 @@ impl PressureEvaluator {
force: bool,
) -> PressureState {
self.rotate_window_if_needed(now, cfg);
if !cfg.backpressure_enabled {
self.state = PressureState::Normal;
self.candidate_state = PressureState::Normal;
self.candidate_hits = 0;
self.rounds_since_eval = 0;
return self.state;
}
self.rounds_since_eval = self.rounds_since_eval.saturating_add(1);
if !force && self.rounds_since_eval < cfg.evaluate_every_rounds.max(1) {
return self.state;
@@ -133,6 +142,10 @@ impl PressureEvaluator {
max_total_queued_bytes: u64,
signals: PressureSignals,
) -> PressureState {
if !cfg.backpressure_enabled {
return PressureState::Normal;
}
let queue_ratio_pct = if max_total_queued_bytes == 0 {
100
} else {
@@ -146,57 +159,59 @@ impl PressureEvaluator {
((signals.standing_flows.saturating_mul(100)) / signals.active_flows).min(100) as u8
};
let mut pressured = false;
let mut saturated = false;
let mut pressure_score = 0u8;
let queue_saturated_pct = cfg
.queue_ratio_shedding_pct
.min(cfg.queue_ratio_saturated_pct);
if queue_ratio_pct >= cfg.queue_ratio_pressured_pct {
pressured = true;
pressure_score = pressure_score.max(1);
}
if queue_ratio_pct >= queue_saturated_pct {
saturated = true;
if queue_ratio_pct >= cfg.queue_ratio_shedding_pct {
pressure_score = pressure_score.max(2);
}
if queue_ratio_pct >= cfg.queue_ratio_saturated_pct {
pressure_score = pressure_score.max(3);
}
let standing_saturated_pct = cfg
.standing_ratio_shedding_pct
.min(cfg.standing_ratio_saturated_pct);
if standing_ratio_pct >= cfg.standing_ratio_pressured_pct {
pressured = true;
pressure_score = pressure_score.max(1);
}
if standing_ratio_pct >= standing_saturated_pct {
saturated = true;
if standing_ratio_pct >= cfg.standing_ratio_shedding_pct {
pressure_score = pressure_score.max(2);
}
if standing_ratio_pct >= cfg.standing_ratio_saturated_pct {
pressure_score = pressure_score.max(3);
}
let rejects_saturated = cfg.rejects_shedding.min(cfg.rejects_saturated);
if self.admission_rejects_window >= cfg.rejects_pressured {
pressured = true;
pressure_score = pressure_score.max(1);
}
if self.admission_rejects_window >= rejects_saturated {
saturated = true;
if self.admission_rejects_window >= cfg.rejects_shedding {
pressure_score = pressure_score.max(2);
}
if self.admission_rejects_window >= cfg.rejects_saturated {
pressure_score = pressure_score.max(3);
}
let stalls_saturated = cfg.stalls_shedding.min(cfg.stalls_saturated);
if self.route_stalls_window >= cfg.stalls_pressured {
pressured = true;
pressure_score = pressure_score.max(1);
}
if self.route_stalls_window >= stalls_saturated {
saturated = true;
if self.route_stalls_window >= cfg.stalls_shedding {
pressure_score = pressure_score.max(2);
}
if self.route_stalls_window >= cfg.stalls_saturated {
pressure_score = pressure_score.max(3);
}
if signals.backpressured_flows > signals.active_flows.saturating_div(2)
&& signals.active_flows > 0
{
pressured = true;
pressure_score = pressure_score.max(2);
}
if saturated {
PressureState::Saturated
} else if pressured {
PressureState::Pressured
} else {
PressureState::Normal
match pressure_score {
0 => PressureState::Normal,
1 => PressureState::Pressured,
2 => PressureState::Shedding,
_ => PressureState::Saturated,
}
}

View File

@@ -14,6 +14,7 @@ use super::pressure::{PressureConfig, PressureEvaluator, PressureSignals};
#[derive(Debug, Clone)]
pub(crate) struct WorkerFairnessConfig {
pub(crate) worker_id: u16,
pub(crate) backpressure_enabled: bool,
pub(crate) max_active_flows: usize,
pub(crate) max_total_queued_bytes: u64,
pub(crate) max_flow_queued_bytes: u64,
@@ -36,6 +37,7 @@ impl Default for WorkerFairnessConfig {
fn default() -> Self {
Self {
worker_id: 0,
backpressure_enabled: true,
max_active_flows: 4096,
max_total_queued_bytes: 16 * 1024 * 1024,
max_flow_queued_bytes: 512 * 1024,
@@ -107,7 +109,8 @@ pub(crate) struct WorkerFairnessState {
}
impl WorkerFairnessState {
pub(crate) fn new(config: WorkerFairnessConfig, now: Instant) -> Self {
pub(crate) fn new(mut config: WorkerFairnessConfig, now: Instant) -> Self {
config.pressure.backpressure_enabled = config.backpressure_enabled;
let bucket_count = config.soft_bucket_count.max(1);
Self {
config,
@@ -134,6 +137,15 @@ impl WorkerFairnessState {
self.pressure.state()
}
pub(crate) fn set_backpressure_enabled(&mut self, enabled: bool) {
if self.config.backpressure_enabled == enabled {
return;
}
self.config.backpressure_enabled = enabled;
self.config.pressure.backpressure_enabled = enabled;
self.evaluate_pressure(Instant::now(), true);
}
pub(crate) fn snapshot(&self) -> WorkerFairnessSnapshot {
WorkerFairnessSnapshot {
pressure_state: self.pressure.state(),
@@ -166,7 +178,7 @@ impl WorkerFairnessState {
};
let frame_bytes = frame.queued_bytes();
if self.pressure.state() == PressureState::Saturated {
if self.config.backpressure_enabled && self.pressure.state() == PressureState::Saturated {
self.pressure
.note_admission_reject(now, &self.config.pressure);
self.enqueue_rejects = self.enqueue_rejects.saturating_add(1);
@@ -231,7 +243,8 @@ impl WorkerFairnessState {
return AdmissionDecision::RejectFlowCap;
}
if self.pressure.state() >= PressureState::Shedding
if self.config.backpressure_enabled
&& self.pressure.state() >= PressureState::Shedding
&& entry.fairness.standing_state == StandingQueueState::Standing
{
self.pressure
@@ -422,8 +435,10 @@ impl WorkerFairnessState {
DispatchAction::Continue
}
DispatchFeedback::QueueFull => {
self.pressure.note_route_stall(now, &self.config.pressure);
self.downstream_stalls = self.downstream_stalls.saturating_add(1);
if self.config.backpressure_enabled {
self.pressure.note_route_stall(now, &self.config.pressure);
self.downstream_stalls = self.downstream_stalls.saturating_add(1);
}
let state = self.pressure.state();
let Some(flow) = self.flows.get_mut(&conn_id) else {
self.evaluate_pressure(now, true);
@@ -433,16 +448,19 @@ impl WorkerFairnessState {
let before_membership = Self::flow_membership(&flow.fairness);
let mut enqueue_active = false;
flow.fairness.consecutive_stalls =
flow.fairness.consecutive_stalls.saturating_add(1);
flow.fairness.scheduler_state = FlowSchedulerState::Backpressured;
flow.fairness.pressure_class = FlowPressureClass::Backpressured;
if self.config.backpressure_enabled {
flow.fairness.consecutive_stalls =
flow.fairness.consecutive_stalls.saturating_add(1);
flow.fairness.scheduler_state = FlowSchedulerState::Backpressured;
flow.fairness.pressure_class = FlowPressureClass::Backpressured;
}
let should_shed_frame = matches!(state, PressureState::Saturated)
|| (matches!(state, PressureState::Shedding)
&& flow.fairness.standing_state == StandingQueueState::Standing
&& flow.fairness.consecutive_stalls
>= self.config.max_consecutive_stalls_before_shed);
let should_shed_frame = self.config.backpressure_enabled
&& (matches!(state, PressureState::Saturated)
|| (matches!(state, PressureState::Shedding)
&& flow.fairness.standing_state == StandingQueueState::Standing
&& flow.fairness.consecutive_stalls
>= self.config.max_consecutive_stalls_before_shed));
if should_shed_frame {
self.shed_drops = self.shed_drops.saturating_add(1);
@@ -467,8 +485,9 @@ impl WorkerFairnessState {
Self::classify_flow(&self.config, state, now, &mut flow.fairness);
let after_membership = Self::flow_membership(&flow.fairness);
let should_close_flow = flow.fairness.consecutive_stalls
>= self.config.max_consecutive_stalls_before_close
let should_close_flow = self.config.backpressure_enabled
&& flow.fairness.consecutive_stalls
>= self.config.max_consecutive_stalls_before_close
&& self.pressure.state() == PressureState::Saturated;
(
before_membership,

View File

@@ -1794,6 +1794,8 @@ mod tests {
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_enabled,
general.me_route_fairshare_enabled,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,

View File

@@ -46,6 +46,7 @@ mod send_adversarial_tests;
mod wire;
use bytes::Bytes;
use tokio::sync::OwnedSemaphorePermit;
#[allow(unused_imports)]
pub use config_updater::{
@@ -68,9 +69,32 @@ pub use secret::{fetch_proxy_secret, fetch_proxy_secret_with_upstream};
pub(crate) use selftest::{bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots};
pub use wire::proto_flags_for_tag;
/// Holds D2C queued-byte capacity until a routed payload is consumed or dropped.
pub struct RouteBytePermit {
_permit: OwnedSemaphorePermit,
}
impl std::fmt::Debug for RouteBytePermit {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RouteBytePermit").finish_non_exhaustive()
}
}
impl RouteBytePermit {
pub(crate) fn new(permit: OwnedSemaphorePermit) -> Self {
Self { _permit: permit }
}
}
/// Response routed from middle proxy readers to client relay tasks.
#[derive(Debug)]
pub enum MeResponse {
Data { flags: u32, data: Bytes },
/// Downstream payload with its queued-byte reservation.
Data {
flags: u32,
data: Bytes,
route_permit: Option<RouteBytePermit>,
},
Ack(u32),
Close,
}

View File

@@ -396,6 +396,8 @@ pub(super) struct WriterSelectionPolicyCore {
pub(super) struct TransportPolicyCore {
pub(super) me_socks_kdf_policy: AtomicU8,
pub(super) me_route_backpressure_enabled: Arc<AtomicBool>,
pub(super) me_route_fairshare_enabled: Arc<AtomicBool>,
pub(super) me_reader_route_data_wait_ms: Arc<AtomicU64>,
}
@@ -548,6 +550,8 @@ impl MePool {
me_socks_kdf_policy: MeSocksKdfPolicy,
me_writer_cmd_channel_capacity: usize,
me_route_channel_capacity: usize,
me_route_backpressure_enabled: bool,
me_route_fairshare_enabled: bool,
me_route_backpressure_base_timeout_ms: u64,
me_route_backpressure_high_timeout_ms: u64,
me_route_backpressure_high_watermark_pct: u8,
@@ -783,6 +787,10 @@ impl MePool {
}),
transport_policy: Arc::new(TransportPolicyCore {
me_socks_kdf_policy: AtomicU8::new(me_socks_kdf_policy.as_u8()),
me_route_backpressure_enabled: Arc::new(AtomicBool::new(
me_route_backpressure_enabled,
)),
me_route_fairshare_enabled: Arc::new(AtomicBool::new(me_route_fairshare_enabled)),
me_reader_route_data_wait_ms: Arc::new(AtomicU64::new(
me_reader_route_data_wait_ms,
)),
@@ -1245,6 +1253,8 @@ impl MePool {
pub fn update_runtime_transport_policy(
&self,
socks_kdf_policy: MeSocksKdfPolicy,
route_backpressure_enabled: bool,
route_fairshare_enabled: bool,
route_backpressure_base_timeout_ms: u64,
route_backpressure_high_timeout_ms: u64,
route_backpressure_high_watermark_pct: u8,
@@ -1253,6 +1263,12 @@ impl MePool {
self.transport_policy
.me_socks_kdf_policy
.store(socks_kdf_policy.as_u8(), Ordering::Relaxed);
self.transport_policy
.me_route_backpressure_enabled
.store(route_backpressure_enabled, Ordering::Relaxed);
self.transport_policy
.me_route_fairshare_enabled
.store(route_fairshare_enabled, Ordering::Relaxed);
self.transport_policy
.me_reader_route_data_wait_ms
.store(reader_route_data_wait_ms, Ordering::Relaxed);

View File

@@ -436,6 +436,9 @@ impl MePool {
let cancel_signal = cancel.clone();
let cancel_select = cancel.clone();
let cancel_cleanup = cancel.clone();
let route_backpressure_enabled =
self.transport_policy.me_route_backpressure_enabled.clone();
let route_fairshare_enabled = self.transport_policy.me_route_fairshare_enabled.clone();
let reader_route_data_wait_ms = self.transport_policy.me_reader_route_data_wait_ms.clone();
tokio::spawn(async move {
@@ -458,6 +461,8 @@ impl MePool {
writer_id,
degraded,
rtt_ema_ms_x10,
route_backpressure_enabled,
route_fairshare_enabled,
reader_route_data_wait_ms,
cancel_reader,
) => WriterLifecycleExit::Reader(reader_res),

View File

@@ -45,7 +45,15 @@ fn is_data_route_queue_full(result: RouteResult) -> bool {
)
}
fn should_close_on_queue_full_streak(streak: u8, pressure_state: PressureState) -> bool {
fn should_close_on_queue_full_streak_with_policy(
streak: u8,
pressure_state: PressureState,
backpressure_enabled: bool,
) -> bool {
if !backpressure_enabled {
return false;
}
if pressure_state < PressureState::Shedding {
return false;
}
@@ -76,6 +84,7 @@ async fn route_data_with_retry(
MeResponse::Data {
flags,
data: data.clone(),
route_permit: None,
},
timeout_ms,
)
@@ -160,6 +169,7 @@ async fn drain_fairness_scheduler(
reg: &ConnRegistry,
tx: &mpsc::Sender<WriterCommand>,
data_route_queue_full_streak: &mut HashMap<u64, u8>,
backpressure_enabled: bool,
route_wait_ms: u64,
stats: &Stats,
) {
@@ -188,7 +198,11 @@ async fn drain_fairness_scheduler(
if is_data_route_queue_full(routed) {
let streak = data_route_queue_full_streak.entry(cid).or_insert(0);
*streak = streak.saturating_add(1);
if should_close_on_queue_full_streak(*streak, pressure_state) {
if should_close_on_queue_full_streak_with_policy(
*streak,
pressure_state,
backpressure_enabled,
) {
fairness.remove_flow(cid);
data_route_queue_full_streak.remove(&cid);
reg.unregister(cid).await;
@@ -220,6 +234,8 @@ pub(crate) async fn reader_loop(
writer_id: u64,
degraded: Arc<AtomicBool>,
writer_rtt_ema_ms_x10: Arc<AtomicU32>,
route_backpressure_enabled: Arc<AtomicBool>,
route_fairshare_enabled: Arc<AtomicBool>,
reader_route_data_wait_ms: Arc<AtomicU64>,
cancel: CancellationToken,
) -> Result<()> {
@@ -236,14 +252,19 @@ pub(crate) async fn reader_loop(
max_flow_queued_bytes: (reg.route_channel_capacity() as u64)
.saturating_mul(2 * 1024)
.clamp(64 * 1024, 2 * 1024 * 1024),
backpressure_enabled: route_backpressure_enabled.load(Ordering::Relaxed),
..WorkerFairnessConfig::default()
},
Instant::now(),
);
let mut fairness_snapshot = fairness.snapshot();
loop {
let backpressure_enabled = route_backpressure_enabled.load(Ordering::Relaxed);
let fairshare_enabled = route_fairshare_enabled.load(Ordering::Relaxed);
fairness.set_backpressure_enabled(backpressure_enabled);
let fairness_has_backlog = should_schedule_fairness_retry(&fairness_snapshot);
let mut tmp = [0u8; 65_536];
let backlog_retry_enabled = should_schedule_fairness_retry(&fairness_snapshot);
let backlog_retry_enabled = fairness_has_backlog;
let backlog_retry_delay =
fairness_retry_delay(reader_route_data_wait_ms.load(Ordering::Relaxed));
let mut retry_only = false;
@@ -262,6 +283,7 @@ pub(crate) async fn reader_loop(
reg.as_ref(),
&tx,
&mut data_route_queue_full_streak,
backpressure_enabled,
route_wait_ms,
stats.as_ref(),
)
@@ -346,20 +368,56 @@ pub(crate) async fn reader_loop(
let data = body.slice(12..);
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
let admission = fairness.enqueue_data(cid, flags, data, Instant::now());
if !matches!(admission, AdmissionDecision::Admit) {
stats.increment_me_route_drop_queue_full();
stats.increment_me_route_drop_queue_full_high();
let streak = data_route_queue_full_streak.entry(cid).or_insert(0);
*streak = streak.saturating_add(1);
let pressure_state = fairness.pressure_state();
if should_close_on_queue_full_streak(*streak, pressure_state)
|| matches!(admission, AdmissionDecision::RejectSaturated)
{
if fairshare_enabled {
let admission = fairness.enqueue_data(cid, flags, data, Instant::now());
if !matches!(admission, AdmissionDecision::Admit) {
stats.increment_me_route_drop_queue_full();
stats.increment_me_route_drop_queue_full_high();
let streak = data_route_queue_full_streak.entry(cid).or_insert(0);
*streak = streak.saturating_add(1);
let pressure_state = fairness.pressure_state();
if should_close_on_queue_full_streak_with_policy(
*streak,
pressure_state,
backpressure_enabled,
) || (backpressure_enabled
&& matches!(admission, AdmissionDecision::RejectSaturated))
{
fairness.remove_flow(cid);
data_route_queue_full_streak.remove(&cid);
reg.unregister(cid).await;
send_close_conn(&tx, cid).await;
}
}
} else {
let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed);
let routed =
route_data_with_retry(reg.as_ref(), cid, flags, data, route_wait_ms).await;
if matches!(routed, RouteResult::Routed) {
data_route_queue_full_streak.remove(&cid);
continue;
}
report_route_drop(routed, stats.as_ref());
if should_close_on_route_result_for_data(routed) {
fairness.remove_flow(cid);
data_route_queue_full_streak.remove(&cid);
reg.unregister(cid).await;
send_close_conn(&tx, cid).await;
continue;
}
if is_data_route_queue_full(routed) {
let streak = data_route_queue_full_streak.entry(cid).or_insert(0);
*streak = streak.saturating_add(1);
if should_close_on_queue_full_streak_with_policy(
*streak,
PressureState::Shedding,
backpressure_enabled,
) {
fairness.remove_flow(cid);
data_route_queue_full_streak.remove(&cid);
reg.unregister(cid).await;
send_close_conn(&tx, cid).await;
}
}
}
} else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 {
@@ -465,6 +523,7 @@ pub(crate) async fn reader_loop(
reg.as_ref(),
&tx,
&mut data_route_queue_full_streak,
backpressure_enabled,
route_wait_ms,
stats.as_ref(),
)
@@ -486,9 +545,9 @@ mod tests {
use super::{
MeResponse, RouteResult, WorkerFairnessSnapshot, fairness_retry_delay,
is_data_route_queue_full, route_data_with_retry, should_close_on_queue_full_streak,
should_close_on_route_result_for_ack, should_close_on_route_result_for_data,
should_schedule_fairness_retry,
is_data_route_queue_full, route_data_with_retry,
should_close_on_queue_full_streak_with_policy, should_close_on_route_result_for_ack,
should_close_on_route_result_for_data, should_schedule_fairness_retry,
};
#[test]
@@ -511,22 +570,35 @@ mod tests {
assert!(is_data_route_queue_full(RouteResult::QueueFullBase));
assert!(is_data_route_queue_full(RouteResult::QueueFullHigh));
assert!(!is_data_route_queue_full(RouteResult::NoConn));
assert!(!should_close_on_queue_full_streak(1, PressureState::Normal));
assert!(!should_close_on_queue_full_streak(
assert!(!should_close_on_queue_full_streak_with_policy(
1,
PressureState::Normal,
true
));
assert!(!should_close_on_queue_full_streak_with_policy(
2,
PressureState::Pressured
PressureState::Pressured,
true
));
assert!(!should_close_on_queue_full_streak(
assert!(!should_close_on_queue_full_streak_with_policy(
3,
PressureState::Pressured
PressureState::Pressured,
true
));
assert!(should_close_on_queue_full_streak(
assert!(should_close_on_queue_full_streak_with_policy(
3,
PressureState::Shedding
PressureState::Shedding,
true
));
assert!(should_close_on_queue_full_streak(
assert!(should_close_on_queue_full_streak_with_policy(
u8::MAX,
PressureState::Saturated
PressureState::Saturated,
true
));
assert!(!should_close_on_queue_full_streak_with_policy(
u8::MAX,
PressureState::Saturated,
false
));
}
@@ -568,7 +640,7 @@ mod tests {
let routed = route_data_with_retry(&reg, conn_id, 0, Bytes::from_static(b"a"), 20).await;
assert!(matches!(routed, RouteResult::Routed));
match rx.recv().await {
Some(MeResponse::Data { flags, data }) => {
Some(MeResponse::Data { flags, data, .. }) => {
assert_eq!(flags, 0);
assert_eq!(data, Bytes::from_static(b"a"));
}

View File

@@ -1,18 +1,22 @@
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use dashmap::DashMap;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::{Mutex, mpsc};
use tokio::sync::{Mutex, Semaphore, mpsc};
use super::MeResponse;
use super::codec::WriterCommand;
use super::{MeResponse, RouteBytePermit};
const ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS: u64 = 25;
const ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS: u64 = 120;
const ROUTE_BACKPRESSURE_HIGH_WATERMARK_PCT: u8 = 80;
const ROUTE_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024;
const ROUTE_QUEUED_PERMITS_PER_SLOT: usize = 4;
const ROUTE_QUEUED_MAX_FRAME_PERMITS: usize = 1024;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouteResult {
@@ -53,6 +57,7 @@ pub(super) struct WriterActivitySnapshot {
struct RoutingTable {
map: DashMap<u64, mpsc::Sender<MeResponse>>,
byte_budget: DashMap<u64, Arc<Semaphore>>,
}
struct WriterTable {
@@ -105,6 +110,7 @@ pub struct ConnRegistry {
route_backpressure_base_timeout_ms: AtomicU64,
route_backpressure_high_timeout_ms: AtomicU64,
route_backpressure_high_watermark_pct: AtomicU8,
route_byte_permits_per_conn: usize,
}
impl ConnRegistry {
@@ -116,10 +122,23 @@ impl ConnRegistry {
}
pub fn with_route_channel_capacity(route_channel_capacity: usize) -> Self {
let route_channel_capacity = route_channel_capacity.max(1);
Self::with_route_limits(
route_channel_capacity,
Self::route_byte_permit_budget(route_channel_capacity),
)
}
fn with_route_limits(
route_channel_capacity: usize,
route_byte_permits_per_conn: usize,
) -> Self {
let start = rand::random::<u64>() | 1;
let route_channel_capacity = route_channel_capacity.max(1);
Self {
routing: RoutingTable {
map: DashMap::new(),
byte_budget: DashMap::new(),
},
writers: WriterTable {
map: DashMap::new(),
@@ -131,15 +150,30 @@ impl ConnRegistry {
inner: Mutex::new(BindingInner::new()),
},
next_id: AtomicU64::new(start),
route_channel_capacity: route_channel_capacity.max(1),
route_channel_capacity,
route_backpressure_base_timeout_ms: AtomicU64::new(ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS),
route_backpressure_high_timeout_ms: AtomicU64::new(ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS),
route_backpressure_high_watermark_pct: AtomicU8::new(
ROUTE_BACKPRESSURE_HIGH_WATERMARK_PCT,
),
route_byte_permits_per_conn: route_byte_permits_per_conn.max(1),
}
}
fn route_data_permits(data_len: usize) -> u32 {
data_len
.max(1)
.div_ceil(ROUTE_QUEUED_BYTE_PERMIT_UNIT)
.min(u32::MAX as usize) as u32
}
fn route_byte_permit_budget(route_channel_capacity: usize) -> usize {
route_channel_capacity
.saturating_mul(ROUTE_QUEUED_PERMITS_PER_SLOT)
.max(ROUTE_QUEUED_MAX_FRAME_PERMITS)
.max(1)
}
pub fn route_channel_capacity(&self) -> usize {
self.route_channel_capacity
}
@@ -149,6 +183,14 @@ impl ConnRegistry {
Self::with_route_channel_capacity(4096)
}
#[cfg(test)]
fn with_route_byte_permits_for_tests(
route_channel_capacity: usize,
route_byte_permits_per_conn: usize,
) -> Self {
Self::with_route_limits(route_channel_capacity, route_byte_permits_per_conn)
}
pub fn update_route_backpressure_policy(
&self,
base_timeout_ms: u64,
@@ -170,6 +212,10 @@ impl ConnRegistry {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = mpsc::channel(self.route_channel_capacity);
self.routing.map.insert(id, tx);
self.routing.byte_budget.insert(
id,
Arc::new(Semaphore::new(self.route_byte_permits_per_conn)),
);
(id, rx)
}
@@ -186,6 +232,7 @@ impl ConnRegistry {
/// Unregister connection, returning associated writer_id if any.
pub async fn unregister(&self, id: u64) -> Option<u64> {
self.routing.map.remove(&id);
self.routing.byte_budget.remove(&id);
self.hot_binding.map.remove(&id);
let mut binding = self.binding.inner.lock().await;
binding.meta.remove(&id);
@@ -206,6 +253,64 @@ impl ConnRegistry {
None
}
async fn attach_route_byte_permit(
&self,
id: u64,
resp: MeResponse,
timeout_ms: Option<u64>,
) -> std::result::Result<MeResponse, RouteResult> {
let MeResponse::Data {
flags,
data,
route_permit,
} = resp
else {
return Ok(resp);
};
if route_permit.is_some() {
return Ok(MeResponse::Data {
flags,
data,
route_permit,
});
}
let Some(semaphore) = self
.routing
.byte_budget
.get(&id)
.map(|entry| entry.value().clone())
else {
return Err(RouteResult::NoConn);
};
let permits = Self::route_data_permits(data.len());
let permit = match timeout_ms {
Some(0) => semaphore
.try_acquire_many_owned(permits)
.map_err(|_| RouteResult::QueueFullHigh)?,
Some(timeout_ms) => {
let acquire = semaphore.acquire_many_owned(permits);
match tokio::time::timeout(Duration::from_millis(timeout_ms.max(1)), acquire).await
{
Ok(Ok(permit)) => permit,
Ok(Err(_)) => return Err(RouteResult::ChannelClosed),
Err(_) => return Err(RouteResult::QueueFullHigh),
}
}
None => semaphore
.acquire_many_owned(permits)
.await
.map_err(|_| RouteResult::ChannelClosed)?,
};
Ok(MeResponse::Data {
flags,
data,
route_permit: Some(RouteBytePermit::new(permit)),
})
}
#[allow(dead_code)]
pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult {
let tx = self.routing.map.get(&id).map(|entry| entry.value().clone());
@@ -214,15 +319,23 @@ impl ConnRegistry {
return RouteResult::NoConn;
};
let base_timeout_ms = self
.route_backpressure_base_timeout_ms
.load(Ordering::Relaxed)
.max(1);
let resp = match self
.attach_route_byte_permit(id, resp, Some(base_timeout_ms))
.await
{
Ok(resp) => resp,
Err(result) => return result,
};
match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed,
Err(TrySendError::Full(resp)) => {
// Absorb short bursts without dropping/closing the session immediately.
let base_timeout_ms = self
.route_backpressure_base_timeout_ms
.load(Ordering::Relaxed)
.max(1);
let high_timeout_ms = self
.route_backpressure_high_timeout_ms
.load(Ordering::Relaxed)
@@ -266,6 +379,10 @@ impl ConnRegistry {
let Some(tx) = tx else {
return RouteResult::NoConn;
};
let resp = match self.attach_route_byte_permit(id, resp, Some(0)).await {
Ok(resp) => resp,
Err(result) => return result,
};
match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
@@ -289,6 +406,13 @@ impl ConnRegistry {
let Some(tx) = tx else {
return RouteResult::NoConn;
};
let resp = match self
.attach_route_byte_permit(id, resp, Some(timeout_ms))
.await
{
Ok(resp) => resp,
Err(result) => return result,
};
match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
@@ -541,8 +665,10 @@ impl ConnRegistry {
mod tests {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use super::ConnMeta;
use super::ConnRegistry;
use bytes::Bytes;
use super::{ConnMeta, ConnRegistry, RouteResult};
use crate::transport::middle_proxy::MeResponse;
#[tokio::test]
async fn writer_activity_snapshot_tracks_writer_and_dc_load() {
@@ -608,6 +734,55 @@ mod tests {
assert_eq!(snapshot.active_sessions_by_target_dc.get(&4), Some(&1));
}
#[tokio::test]
async fn route_data_is_bounded_by_byte_permits_before_channel_capacity() {
let registry = ConnRegistry::with_route_byte_permits_for_tests(4, 1);
let (conn_id, mut rx) = registry.register().await;
let routed = registry
.route_nowait(
conn_id,
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xAA]),
route_permit: None,
},
)
.await;
assert!(matches!(routed, RouteResult::Routed));
let blocked = registry
.route_nowait(
conn_id,
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xBB]),
route_permit: None,
},
)
.await;
assert!(
matches!(blocked, RouteResult::QueueFullHigh),
"byte budget must reject data before count capacity is exhausted"
);
drop(rx.recv().await);
let routed_after_drain = registry
.route_nowait(
conn_id,
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xCC]),
route_permit: None,
},
)
.await;
assert!(
matches!(routed_after_drain, RouteResult::Routed),
"receiving queued data must release byte permits"
);
}
#[tokio::test]
async fn bind_writer_rebinds_conn_atomically() {
let registry = ConnRegistry::new();

View File

@@ -104,6 +104,8 @@ async fn make_pool(
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_enabled,
general.me_route_fairshare_enabled,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,

View File

@@ -102,6 +102,8 @@ async fn make_pool(
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_enabled,
general.me_route_fairshare_enabled,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,

View File

@@ -97,6 +97,8 @@ async fn make_pool(me_pool_drain_threshold: u64) -> Arc<MePool> {
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_enabled,
general.me_route_fairshare_enabled,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,

View File

@@ -86,6 +86,8 @@ async fn make_pool() -> Arc<MePool> {
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_enabled,
general.me_route_fairshare_enabled,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,

View File

@@ -91,6 +91,8 @@ async fn make_pool() -> Arc<MePool> {
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_enabled,
general.me_route_fairshare_enabled,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,

View File

@@ -97,6 +97,8 @@ async fn make_pool() -> (Arc<MePool>, Arc<SecureRandom>) {
MeSocksKdfPolicy::default(),
general.me_writer_cmd_channel_capacity,
general.me_route_channel_capacity,
general.me_route_backpressure_enabled,
general.me_route_fairshare_enabled,
general.me_route_backpressure_base_timeout_ms,
general.me_route_backpressure_high_timeout_ms,
general.me_route_backpressure_high_watermark_pct,