mirror of
https://github.com/telemt/telemt.git
synced 2026-06-16 23:34:08 +03:00
Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
49742d38a7 | ||
|
|
869d8517a0 | ||
|
|
e82ce634d6 | ||
|
|
f1f46fac42 | ||
|
|
37d0184a0b | ||
|
|
d81d7dba62 | ||
|
|
04b8d8365c | ||
|
|
2e26bfb86e | ||
|
|
d414c73c9b | ||
|
|
b153782597 | ||
|
|
9dc67727b0 | ||
|
|
2d02fbe548 | ||
|
|
2675779915 | ||
|
|
c4954f745f | ||
|
|
f33abfb09e | ||
|
|
9904da737a | ||
|
|
9a3ff726b2 | ||
|
|
942882f9de | ||
|
|
eeff16c3fd | ||
|
|
c86dc2f65e | ||
|
|
1cbde70a14 | ||
|
|
26cd4734de | ||
|
|
52a1b66ad7 |
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2938,7 +2938,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
|
||||
|
||||
[[package]]
|
||||
name = "telemt"
|
||||
version = "3.4.16"
|
||||
version = "3.4.18"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"anyhow",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "telemt"
|
||||
version = "3.4.16"
|
||||
version = "3.4.18"
|
||||
edition = "2024"
|
||||
|
||||
[features]
|
||||
|
||||
@@ -2219,6 +2219,10 @@ Note: This section also accepts the legacy alias `[server.admin_api]` (same sche
|
||||
| [`ip`](#ip) | `IpAddr` | — | `✘` |
|
||||
| [`port`](#port-serverlisteners) | `u16` | `server.port` | `✘` |
|
||||
| [`client_mss`](#client_mss-serverlisteners) | `String` | `[server].client_mss` | `✘` |
|
||||
| [`synlimit`](#synlimit-serverlisteners) | `false`, `"iptables"`, or `"nftables"` | `false` | `✔` |
|
||||
| [`synlimit_seconds`](#synlimit_seconds-serverlisteners) | `u32` | `1` | `✔` |
|
||||
| [`synlimit_hitcount`](#synlimit_hitcount-serverlisteners) | `u32` | `1` | `✔` |
|
||||
| [`synlimit_burst`](#synlimit_burst-serverlisteners) | `u32` | `2` | `✔` |
|
||||
| [`announce`](#announce) | `String` | — | `✘` |
|
||||
| [`announce_ip`](#announce_ip) | `IpAddr` | — | `✘` |
|
||||
| [`proxy_protocol`](#proxy_protocol) | `bool` | — | `✘` |
|
||||
@@ -2254,6 +2258,58 @@ Note: This section also accepts the legacy alias `[server.admin_api]` (same sche
|
||||
port = 443
|
||||
client_mss = "256"
|
||||
```
|
||||
## synlimit (server.listeners)
|
||||
- **Constraints / validation**: `false`, `"iptables"`, or `"nftables"`. Omitted or `false` disables SYN limiting for this listener.
|
||||
- **Description**: Installs per-listener Linux netfilter SYN limiter rules for the listener port. `"iptables"` uses `iptables`/`ip6tables` filter rules with the `hashlimit` match as a per-source token bucket. `"nftables"` uses per-source `meter` rules with `limit rate over` and auto-detects whether the host already uses `inet`, `ip`, or `ip6` table families before creating Telemt-owned tables. The token-bucket rate is `synlimit_hitcount / synlimit_seconds`; `synlimit_burst` controls the burst size. Rules are reconciled at runtime and removed during graceful Telemt shutdown; `SIGKILL` cannot be cleaned up by the process. Requires CAP_NET_ADMIN. `synlimit*` changes hot-reload for existing listener endpoints; changing listener `ip` or `port` still requires restart/rebind.
|
||||
- **Example**:
|
||||
|
||||
```toml
|
||||
[[server.listeners]]
|
||||
ip = "0.0.0.0"
|
||||
port = 443
|
||||
synlimit = "iptables"
|
||||
|
||||
[[server.listeners]]
|
||||
ip = "::"
|
||||
port = 443
|
||||
synlimit = "nftables"
|
||||
```
|
||||
## synlimit_seconds (server.listeners)
|
||||
- **Constraints / validation**: `u32`, must be `> 0`. Default is `1`.
|
||||
- **Description**: Token-bucket interval for both SYN limiter backends. The rate is `synlimit_hitcount / synlimit_seconds` and is rendered to native netfilter rate units (`second`, `minute`, `hour`, or `day`).
|
||||
- **Example**:
|
||||
|
||||
```toml
|
||||
[[server.listeners]]
|
||||
ip = "0.0.0.0"
|
||||
port = 443
|
||||
synlimit = "iptables"
|
||||
synlimit_seconds = 1
|
||||
```
|
||||
## synlimit_hitcount (server.listeners)
|
||||
- **Constraints / validation**: `u32`, must be `> 0`. Default is `1`.
|
||||
- **Description**: Token-bucket rate amount for both SYN limiter backends. Together with `synlimit_seconds`, it defines the allowed source-IP SYN rate before excess SYN packets are dropped.
|
||||
- **Example**:
|
||||
|
||||
```toml
|
||||
[[server.listeners]]
|
||||
ip = "0.0.0.0"
|
||||
port = 443
|
||||
synlimit = "iptables"
|
||||
synlimit_hitcount = 1
|
||||
```
|
||||
## synlimit_burst (server.listeners)
|
||||
- **Constraints / validation**: `u32`, must be `> 0`. Default is `2`.
|
||||
- **Description**: Token-bucket burst size for both SYN limiter backends. Higher values allow short connection bursts from the same source IP before the steady-state `synlimit_hitcount / synlimit_seconds` rate is enforced.
|
||||
- **Example**:
|
||||
|
||||
```toml
|
||||
[[server.listeners]]
|
||||
ip = "0.0.0.0"
|
||||
port = 443
|
||||
synlimit = "iptables"
|
||||
synlimit_burst = 2
|
||||
```
|
||||
## announce
|
||||
- **Constraints / validation**: `String` (optional). Must not be empty when set.
|
||||
- **Description**: Public IP/domain announced in proxy links for this listener. Takes precedence over `announce_ip`.
|
||||
|
||||
@@ -2225,6 +2225,10 @@
|
||||
| [`ip`](#ip) | `IpAddr` | — | `✘` |
|
||||
| [`port`](#port-serverlisteners) | `u16` | `server.port` | `✘` |
|
||||
| [`client_mss`](#client_mss-serverlisteners) | `String` | `[server].client_mss` | `✘` |
|
||||
| [`synlimit`](#synlimit-serverlisteners) | `false`, `"iptables"` или `"nftables"` | `false` | `✔` |
|
||||
| [`synlimit_seconds`](#synlimit_seconds-serverlisteners) | `u32` | `1` | `✔` |
|
||||
| [`synlimit_hitcount`](#synlimit_hitcount-serverlisteners) | `u32` | `1` | `✔` |
|
||||
| [`synlimit_burst`](#synlimit_burst-serverlisteners) | `u32` | `2` | `✔` |
|
||||
| [`announce`](#announce) | `String` | — | `✘` |
|
||||
| [`announce_ip`](#announce_ip) | `IpAddr` | — | `✘` |
|
||||
| [`proxy_protocol`](#proxy_protocol) | `bool` | — | `✘` |
|
||||
@@ -2260,6 +2264,58 @@
|
||||
port = 443
|
||||
client_mss = "256"
|
||||
```
|
||||
## synlimit (server.listeners)
|
||||
- **Ограничения / валидация**: `false`, `"iptables"` или `"nftables"`. Если параметр не задан или задан как `false`, SYN limiter для этого listener’а выключен.
|
||||
- **Описание**: Устанавливает per-listener Linux netfilter SYN limiter rules для порта listener’а. `"iptables"` использует `iptables`/`ip6tables` filter rules с `hashlimit` match как per-source token bucket. `"nftables"` использует per-source `meter` rules с `limit rate over` и автоматически определяет, какие table families уже используются на хосте (`inet`, `ip`, `ip6`), перед созданием Telemt-owned tables. Token-bucket rate равен `synlimit_hitcount / synlimit_seconds`; `synlimit_burst` управляет burst size. Rules reconciled at runtime и удаляются при graceful shutdown Telemt; `SIGKILL` процессом не очищается. Требует CAP_NET_ADMIN. Изменения `synlimit*` hot-reload’ятся для существующих listener endpoints; изменение listener `ip` или `port` по-прежнему требует restart/rebind.
|
||||
- **Пример**:
|
||||
|
||||
```toml
|
||||
[[server.listeners]]
|
||||
ip = "0.0.0.0"
|
||||
port = 443
|
||||
synlimit = "iptables"
|
||||
|
||||
[[server.listeners]]
|
||||
ip = "::"
|
||||
port = 443
|
||||
synlimit = "nftables"
|
||||
```
|
||||
## synlimit_seconds (server.listeners)
|
||||
- **Ограничения / валидация**: `u32`, должно быть `> 0`. Значение по умолчанию: `1`.
|
||||
- **Описание**: Token-bucket interval для обоих SYN limiter backends. Rate равен `synlimit_hitcount / synlimit_seconds` и рендерится в native netfilter rate units (`second`, `minute`, `hour` или `day`).
|
||||
- **Пример**:
|
||||
|
||||
```toml
|
||||
[[server.listeners]]
|
||||
ip = "0.0.0.0"
|
||||
port = 443
|
||||
synlimit = "iptables"
|
||||
synlimit_seconds = 1
|
||||
```
|
||||
## synlimit_hitcount (server.listeners)
|
||||
- **Ограничения / валидация**: `u32`, должно быть `> 0`. Значение по умолчанию: `1`.
|
||||
- **Описание**: Token-bucket rate amount для обоих SYN limiter backends. Вместе с `synlimit_seconds` задает разрешенный source-IP SYN rate до того, как excess SYN packets начнут drop’аться.
|
||||
- **Пример**:
|
||||
|
||||
```toml
|
||||
[[server.listeners]]
|
||||
ip = "0.0.0.0"
|
||||
port = 443
|
||||
synlimit = "iptables"
|
||||
synlimit_hitcount = 1
|
||||
```
|
||||
## synlimit_burst (server.listeners)
|
||||
- **Ограничения / валидация**: `u32`, должно быть `> 0`. Значение по умолчанию: `2`.
|
||||
- **Описание**: Token-bucket burst size для обоих SYN limiter backends. Более высокие значения разрешают short connection bursts с одного source IP перед применением steady-state rate `synlimit_hitcount / synlimit_seconds`.
|
||||
- **Пример**:
|
||||
|
||||
```toml
|
||||
[[server.listeners]]
|
||||
ip = "0.0.0.0"
|
||||
port = 443
|
||||
synlimit = "iptables"
|
||||
synlimit_burst = 2
|
||||
```
|
||||
## announce
|
||||
- **Ограничения / валидация**: `String` (необязательный параметр). Не должен быть пустым, если задан.
|
||||
- **Описание**: Публичный IP-адрес или домен, объявляемый в proxy-ссылках для данного listener’а. Имеет приоритет над `announce_ip`.
|
||||
|
||||
@@ -313,6 +313,83 @@ mod tests {
|
||||
assert_eq!(err.code, "section_not_editable");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn patch_rejects_show_link_section() {
|
||||
// show_link is a legacy top-level scalar/array (not a [table]); it cannot
|
||||
// be upserted safely and is superseded by the editable general.links.show.
|
||||
let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n");
|
||||
let patch: Json = serde_json::json!({"show_link": "*"});
|
||||
let err = apply_patch_to_path(&path, &patch, None).await.unwrap_err();
|
||||
assert_eq!(err.code, "section_not_editable");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn patch_general_links_show_is_editable() {
|
||||
// The supported replacement path: edit show via the general.links sub-table.
|
||||
let (path, _d) = temp_config(
|
||||
"[general]\nprefer_ipv6 = false\n[general.links]\nshow = \"*\"\n\
|
||||
[censorship]\ntls_domain = \"a\"\n",
|
||||
);
|
||||
let patch: Json = serde_json::json!({"general": {"links": {"show": ["alice"]}}});
|
||||
let resp = apply_patch_to_path(&path, &patch, None).await.unwrap();
|
||||
assert!(resp.changed.iter().any(|c| c == "general"));
|
||||
let written = tokio::fs::read_to_string(&path).await.unwrap();
|
||||
let parsed: toml::Value = toml::from_str(&written).unwrap();
|
||||
assert_eq!(
|
||||
parsed["general"]["links"]["show"][0].as_str(),
|
||||
Some("alice"),
|
||||
"{written}"
|
||||
);
|
||||
// No leaked top-level [links]/[modes] and no duplicate sub-tables.
|
||||
assert_eq!(written.matches("[general.links]").count(), 1, "{written}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn patch_links_public_port_written_as_integer_not_float_or_string() {
|
||||
// A JSON integer must land on disk as a bare TOML integer (443), never
|
||||
// 443.0 nor "443". The write re-renders from the typed config, so the
|
||||
// u16 field dictates the output format regardless of JSON quirks.
|
||||
let (path, _d) = temp_config("[general]\nprefer_ipv6 = false\n");
|
||||
let patch: Json = serde_json::json!({"general": {"links": {"public_port": 443}}});
|
||||
apply_patch_to_path(&path, &patch, None).await.unwrap();
|
||||
|
||||
let written = tokio::fs::read_to_string(&path).await.unwrap();
|
||||
assert!(written.contains("public_port = 443"), "{written}");
|
||||
assert!(
|
||||
!written.contains("443.0"),
|
||||
"must not be a float:\n{written}"
|
||||
);
|
||||
assert!(
|
||||
!written.contains("\"443\""),
|
||||
"must not be a string:\n{written}"
|
||||
);
|
||||
|
||||
let parsed: toml::Value = toml::from_str(&written).unwrap();
|
||||
assert_eq!(
|
||||
parsed["general"]["links"]["public_port"].as_integer(),
|
||||
Some(443),
|
||||
"{written}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn patch_links_public_port_rejects_float() {
|
||||
// 443.0 cannot deserialize into u16 -> rejected, not silently coerced.
|
||||
let (path, _d) = temp_config("[general]\nprefer_ipv6 = false\n");
|
||||
let patch: Json = serde_json::json!({"general": {"links": {"public_port": 443.0}}});
|
||||
let err = apply_patch_to_path(&path, &patch, None).await.unwrap_err();
|
||||
assert_eq!(err.status, hyper::StatusCode::BAD_REQUEST, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn patch_links_public_port_rejects_string() {
|
||||
// "443" is a string, not a u16 -> rejected.
|
||||
let (path, _d) = temp_config("[general]\nprefer_ipv6 = false\n");
|
||||
let patch: Json = serde_json::json!({"general": {"links": {"public_port": "443"}}});
|
||||
let err = apply_patch_to_path(&path, &patch, None).await.unwrap_err();
|
||||
assert_eq!(err.status, hyper::StatusCode::BAD_REQUEST, "{:?}", err);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn patch_empty_is_rejected() {
|
||||
let (path, _d) = temp_config("[censorship]\ntls_domain = \"a\"\n");
|
||||
|
||||
@@ -102,9 +102,14 @@ pub(super) async fn save_config_to_disk(
|
||||
/// Intentionally excluded (defense-in-depth, enforces the spec's per-node
|
||||
/// identity invariant at the Telemt layer too):
|
||||
///
|
||||
/// - `access` : owned by the users API.
|
||||
/// - `server` : carries per-node identity (`port`, `api`/`api_bind`, listeners).
|
||||
/// - `network` : carries per-node identity (`ipv4`/`ipv6`).
|
||||
/// - `access` : owned by the users API.
|
||||
/// - `server` : carries per-node identity (`port`, `api`/`api_bind`, listeners).
|
||||
/// - `network` : carries per-node identity (`ipv4`/`ipv6`).
|
||||
/// - `show_link` : legacy top-level scalar/array (not a `[table]`), superseded
|
||||
/// by the editable `general.links.show` sub-table. The
|
||||
/// section-upsert machinery here only handles `[table]` /
|
||||
/// `[[array-of-tables]]` blocks; a bare top-level key cannot be
|
||||
/// located or replaced safely, so it is edited via `general`.
|
||||
///
|
||||
/// A future field-level allowlist can re-admit specific safe fields
|
||||
/// (e.g. `network.dns_overrides`) without opening the whole section.
|
||||
@@ -113,7 +118,6 @@ pub(super) const EDITABLE_SECTIONS: &[&str] = &[
|
||||
"timeouts",
|
||||
"censorship",
|
||||
"upstreams",
|
||||
"show_link",
|
||||
"dc_overrides",
|
||||
];
|
||||
|
||||
@@ -162,10 +166,15 @@ fn render_top_level_section(cfg: &ProxyConfig, section: &str) -> Result<String,
|
||||
return Ok(out);
|
||||
}
|
||||
|
||||
let body = toml::to_string(table)
|
||||
// Serialize the table *inside a wrapper keyed by `section`* so the `toml`
|
||||
// crate emits correctly dotted headers for nested sub-tables, e.g.
|
||||
// `[general]` + `[general.modes]` + `[general.links]`. Serializing the
|
||||
// inner table alone would render bare `[modes]`/`[links]` headers, which
|
||||
// would leak as duplicate top-level tables and break config load.
|
||||
let mut wrapper = toml::value::Table::new();
|
||||
wrapper.insert(section.to_string(), table.clone());
|
||||
let mut out = toml::to_string(&toml::Value::Table(wrapper))
|
||||
.map_err(|e| ApiFailure::internal(format!("failed to serialize {}: {}", section, e)))?;
|
||||
let mut out = format!("[{}]\n", section);
|
||||
out.push_str(&body);
|
||||
if !out.ends_with('\n') {
|
||||
out.push('\n');
|
||||
}
|
||||
@@ -328,11 +337,22 @@ fn serialize_toml_key(key: &str) -> Result<String, ApiFailure> {
|
||||
}
|
||||
|
||||
fn upsert_toml_table(source: &str, table_name: &str, replacement: &str) -> String {
|
||||
if let Some((start, end)) = find_toml_table_bounds(source, table_name) {
|
||||
let blocks = find_all_table_blocks(source, table_name);
|
||||
if let Some(&(first_start, first_end)) = blocks.first() {
|
||||
// Replace the first block in place and delete any further blocks that
|
||||
// also belong to this table. Telemt writes a section's sub-tables
|
||||
// contiguously, but a hand-edited config may scatter them; dropping the
|
||||
// extras here prevents the duplicate-table corruption that would
|
||||
// otherwise break config load.
|
||||
let mut out = String::with_capacity(source.len() + replacement.len());
|
||||
out.push_str(&source[..start]);
|
||||
out.push_str(&source[..first_start]);
|
||||
out.push_str(replacement);
|
||||
out.push_str(&source[end..]);
|
||||
let mut cursor = first_end;
|
||||
for &(start, end) in &blocks[1..] {
|
||||
out.push_str(&source[cursor..start]);
|
||||
cursor = end;
|
||||
}
|
||||
out.push_str(&source[cursor..]);
|
||||
return out;
|
||||
}
|
||||
|
||||
@@ -347,29 +367,62 @@ fn upsert_toml_table(source: &str, table_name: &str, replacement: &str) -> Strin
|
||||
out
|
||||
}
|
||||
|
||||
/// Whether a (comment-stripped, trimmed) TOML header line belongs to
|
||||
/// `table_name`: the table itself (`[X]` / `[[X]]`) or any of its nested
|
||||
/// sub-tables (`[X.…]` / `[[X.…]]`). The trailing dot guards against sibling
|
||||
/// prefixes — `access.users` must not match `access.user_enabled`.
|
||||
fn header_belongs_to(header: &str, table_name: &str) -> bool {
|
||||
let body = match header.strip_prefix("[[").and_then(|h| h.strip_suffix("]]")) {
|
||||
Some(body) => body,
|
||||
None => match header.strip_prefix('[').and_then(|h| h.strip_suffix(']')) {
|
||||
Some(body) => body,
|
||||
None => return false,
|
||||
},
|
||||
};
|
||||
let body = body.trim();
|
||||
body == table_name
|
||||
|| body
|
||||
.strip_prefix(table_name)
|
||||
.is_some_and(|rest| rest.starts_with('.'))
|
||||
}
|
||||
|
||||
/// Locate the first contiguous byte range covering `table_name` and the nested
|
||||
/// sub-tables immediately following it. Used for existence checks; see
|
||||
/// [`find_all_table_blocks`] for the full set of (possibly scattered) blocks.
|
||||
fn find_toml_table_bounds(source: &str, table_name: &str) -> Option<(usize, usize)> {
|
||||
let single = format!("[{}]", table_name);
|
||||
let array = format!("[[{}]]", table_name);
|
||||
find_all_table_blocks(source, table_name).into_iter().next()
|
||||
}
|
||||
|
||||
/// Locate every byte range that belongs to `table_name`: the table header and
|
||||
/// its nested sub-tables. Returns one range per contiguous run, so a config
|
||||
/// where a section's sub-tables are scattered (e.g. hand-edited) yields several
|
||||
/// ranges — letting the caller collapse them into a single rendered block.
|
||||
fn find_all_table_blocks(source: &str, table_name: &str) -> Vec<(usize, usize)> {
|
||||
let mut blocks = Vec::new();
|
||||
let mut offset = 0usize;
|
||||
let mut start = None;
|
||||
let mut start: Option<usize> = None;
|
||||
|
||||
for line in source.split_inclusive('\n') {
|
||||
// Drop any inline comment so a hand-edited header like
|
||||
// `[censorship] # note` still matches. Section names never contain `#`.
|
||||
let header = line.trim().split('#').next().unwrap_or("").trim();
|
||||
let is_header = header.starts_with('[');
|
||||
if let Some(start_offset) = start {
|
||||
let is_same_array = header == array;
|
||||
let is_new_header = header.starts_with('[');
|
||||
if is_new_header && !is_same_array {
|
||||
return Some((start_offset, offset));
|
||||
if is_header && !header_belongs_to(header, table_name) {
|
||||
blocks.push((start_offset, offset));
|
||||
start = None;
|
||||
}
|
||||
} else if header == single || header == array {
|
||||
}
|
||||
if start.is_none() && header_belongs_to(header, table_name) {
|
||||
start = Some(offset);
|
||||
}
|
||||
offset = offset.saturating_add(line.len());
|
||||
}
|
||||
|
||||
start.map(|start_offset| (start_offset, source.len()))
|
||||
if let Some(start_offset) = start {
|
||||
blocks.push((start_offset, source.len()));
|
||||
}
|
||||
blocks
|
||||
}
|
||||
|
||||
async fn write_atomic(path: PathBuf, contents: String) -> Result<(), ApiFailure> {
|
||||
@@ -467,6 +520,138 @@ mod tests {
|
||||
assert!(!slice.contains("[server]")); // terminates at the next header
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn save_general_section_keeps_subtables_dotted_without_duplicates() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("config.toml");
|
||||
tokio::fs::write(
|
||||
&path,
|
||||
"[general]\nprefer_ipv6 = false\n\n[general.modes]\ntls = true\n\n\
|
||||
[general.links]\npublic_host = \"old.example\"\n\n[server]\nport = 443\n",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.prefer_ipv6 = true;
|
||||
|
||||
save_sections_to_disk(&path, &cfg, &["general"])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let written = tokio::fs::read_to_string(&path).await.unwrap();
|
||||
|
||||
// No bare top-level [modes] / [links] headers leaked.
|
||||
for line in written.lines() {
|
||||
let header = line.trim();
|
||||
assert_ne!(header, "[modes]", "leaked top-level [modes]:\n{written}");
|
||||
assert_ne!(header, "[links]", "leaked top-level [links]:\n{written}");
|
||||
}
|
||||
|
||||
// Sub-tables kept their dotted prefix exactly once each.
|
||||
assert_eq!(
|
||||
written.matches("[general.modes]").count(),
|
||||
1,
|
||||
"[general.modes] must appear exactly once:\n{written}"
|
||||
);
|
||||
assert_eq!(
|
||||
written.matches("[general.links]").count(),
|
||||
1,
|
||||
"[general.links] must appear exactly once:\n{written}"
|
||||
);
|
||||
|
||||
// Result parses (duplicate tables would error here).
|
||||
toml::from_str::<toml::Value>(&written)
|
||||
.unwrap_or_else(|e| panic!("written config must parse: {e}\n{written}"));
|
||||
|
||||
assert!(written.contains("[server]\nport = 443")); // untouched table kept
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn save_general_section_is_idempotent_across_repeated_saves() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("config.toml");
|
||||
tokio::fs::write(
|
||||
&path,
|
||||
"[general]\nprefer_ipv6 = false\n\n[general.modes]\ntls = true\n\n\
|
||||
[general.links]\npublic_host = \"old.example\"\n",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.prefer_ipv6 = true;
|
||||
|
||||
save_sections_to_disk(&path, &cfg, &["general"])
|
||||
.await
|
||||
.unwrap();
|
||||
save_sections_to_disk(&path, &cfg, &["general"])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let written = tokio::fs::read_to_string(&path).await.unwrap();
|
||||
assert_eq!(written.matches("[general.modes]").count(), 1, "{written}");
|
||||
assert_eq!(written.matches("[general.links]").count(), 1, "{written}");
|
||||
assert_eq!(written.matches("[general]").count(), 1, "{written}");
|
||||
toml::from_str::<toml::Value>(&written)
|
||||
.unwrap_or_else(|e| panic!("written config must parse: {e}\n{written}"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_bounds_spans_dotted_subtables() {
|
||||
let src = "[general]\nprefer_ipv6 = false\n\n[general.modes]\ntls = true\n\n\
|
||||
[general.links]\npublic_host = \"a\"\n\n[server]\nport = 1\n";
|
||||
let bounds = find_toml_table_bounds(src, "general");
|
||||
assert!(bounds.is_some(), "should locate [general] block");
|
||||
let (start, end) = bounds.unwrap();
|
||||
let slice = &src[start..end];
|
||||
assert!(slice.starts_with("[general]"));
|
||||
assert!(slice.contains("[general.modes]")); // spans nested sub-tables
|
||||
assert!(slice.contains("[general.links]"));
|
||||
assert!(!slice.contains("[server]")); // terminates at the next unrelated header
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn find_bounds_does_not_overrun_sibling_prefix() {
|
||||
// access.users must not swallow access.user_enabled (dot guards the prefix).
|
||||
let src = "[access.users]\nalice = \"x\"\n\n[access.user_enabled]\nalice = true\n";
|
||||
let bounds = find_toml_table_bounds(src, "access.users").unwrap();
|
||||
let slice = &src[bounds.0..bounds.1];
|
||||
assert!(slice.starts_with("[access.users]"));
|
||||
assert!(!slice.contains("[access.user_enabled]"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn save_general_handles_non_contiguous_subtables() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let path = dir.path().join("config.toml");
|
||||
// Hand-edited layout: [general.modes] sits AFTER an unrelated [server].
|
||||
tokio::fs::write(
|
||||
&path,
|
||||
"[general]\nprefer_ipv6 = false\n\n[server]\nport = 443\n\n\
|
||||
[general.modes]\ntls = true\n",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.prefer_ipv6 = true;
|
||||
|
||||
save_sections_to_disk(&path, &cfg, &["general"])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let written = tokio::fs::read_to_string(&path).await.unwrap();
|
||||
assert_eq!(
|
||||
written.matches("[general.modes]").count(),
|
||||
1,
|
||||
"non-contiguous [general.modes] must not duplicate:\n{written}"
|
||||
);
|
||||
toml::from_str::<toml::Value>(&written)
|
||||
.unwrap_or_else(|e| panic!("written config must parse: {e}\n{written}"));
|
||||
assert!(written.contains("[server]")); // unrelated section preserved
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn render_user_rate_limits_section() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
|
||||
@@ -54,6 +54,9 @@ const DEFAULT_CONNTRACK_CONTROL_ENABLED: bool = true;
|
||||
const DEFAULT_CONNTRACK_PRESSURE_HIGH_WATERMARK_PCT: u8 = 85;
|
||||
const DEFAULT_CONNTRACK_PRESSURE_LOW_WATERMARK_PCT: u8 = 70;
|
||||
const DEFAULT_CONNTRACK_DELETE_BUDGET_PER_SEC: u64 = 4096;
|
||||
const DEFAULT_SYNLIMIT_SECONDS: u32 = 1;
|
||||
const DEFAULT_SYNLIMIT_HITCOUNT: u32 = 1;
|
||||
const DEFAULT_SYNLIMIT_BURST: u32 = 2;
|
||||
const DEFAULT_UPSTREAM_CONNECT_RETRY_ATTEMPTS: u32 = 2;
|
||||
const DEFAULT_UPSTREAM_UNHEALTHY_FAIL_THRESHOLD: u32 = 5;
|
||||
const DEFAULT_UPSTREAM_CONNECT_BUDGET_MS: u64 = 3000;
|
||||
@@ -243,6 +246,18 @@ pub(crate) fn default_conntrack_delete_budget_per_sec() -> u64 {
|
||||
DEFAULT_CONNTRACK_DELETE_BUDGET_PER_SEC
|
||||
}
|
||||
|
||||
pub(crate) fn default_synlimit_seconds() -> u32 {
|
||||
DEFAULT_SYNLIMIT_SECONDS
|
||||
}
|
||||
|
||||
pub(crate) fn default_synlimit_hitcount() -> u32 {
|
||||
DEFAULT_SYNLIMIT_HITCOUNT
|
||||
}
|
||||
|
||||
pub(crate) fn default_synlimit_burst() -> u32 {
|
||||
DEFAULT_SYNLIMIT_BURST
|
||||
}
|
||||
|
||||
pub(crate) fn default_prefer_4() -> u8 {
|
||||
4
|
||||
}
|
||||
|
||||
@@ -16,10 +16,12 @@
|
||||
//! | `general` | `telemetry` / `me_*_policy` | Applied immediately |
|
||||
//! | `network` | `dns_overrides` | Applied immediately |
|
||||
//! | `access` | All user/quota fields | Effective immediately |
|
||||
//! | `server.listeners` | `synlimit*` for existing endpoints | Netfilter rules reconciled immediately |
|
||||
//!
|
||||
//! Fields that require re-binding sockets (`server.listeners`, legacy
|
||||
//! `server.port`, `censorship.*`, `network.*`, `use_middle_proxy`) are **not**
|
||||
//! applied; a warning is emitted.
|
||||
//! applied, except for SYN limiter fields on unchanged listener endpoints; a
|
||||
//! warning is emitted.
|
||||
//! Non-hot changes are never mixed into the runtime config snapshot.
|
||||
|
||||
use std::collections::BTreeSet;
|
||||
@@ -34,7 +36,8 @@ use tracing::{error, info, warn};
|
||||
|
||||
use super::load::{LoadedConfig, ProxyConfig};
|
||||
use crate::config::{
|
||||
LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel, MeWriterPickMode,
|
||||
ListenerConfig, LogLevel, MeBindStaleMode, MeFloorMode, MeSocksKdfPolicy, MeTelemetryLevel,
|
||||
MeWriterPickMode, SynLimitMode,
|
||||
};
|
||||
|
||||
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
|
||||
@@ -131,6 +134,17 @@ pub struct HotFields {
|
||||
pub user_max_unique_ips_global_each: usize,
|
||||
pub user_max_unique_ips_mode: crate::config::UserMaxUniqueIpsMode,
|
||||
pub user_max_unique_ips_window_secs: u64,
|
||||
pub listener_synlimit: Vec<ListenerSynLimitHotFields>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ListenerSynLimitHotFields {
|
||||
pub ip: IpAddr,
|
||||
pub port: Option<u16>,
|
||||
pub synlimit: SynLimitMode,
|
||||
pub synlimit_seconds: u32,
|
||||
pub synlimit_hitcount: u32,
|
||||
pub synlimit_burst: u32,
|
||||
}
|
||||
|
||||
impl HotFields {
|
||||
@@ -260,6 +274,25 @@ impl HotFields {
|
||||
user_max_unique_ips_global_each: cfg.access.user_max_unique_ips_global_each,
|
||||
user_max_unique_ips_mode: cfg.access.user_max_unique_ips_mode,
|
||||
user_max_unique_ips_window_secs: cfg.access.user_max_unique_ips_window_secs,
|
||||
listener_synlimit: cfg
|
||||
.server
|
||||
.listeners
|
||||
.iter()
|
||||
.map(ListenerSynLimitHotFields::from_listener)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ListenerSynLimitHotFields {
|
||||
fn from_listener(listener: &ListenerConfig) -> Self {
|
||||
Self {
|
||||
ip: listener.ip,
|
||||
port: listener.port,
|
||||
synlimit: listener.synlimit,
|
||||
synlimit_seconds: listener.synlimit_seconds,
|
||||
synlimit_hitcount: listener.synlimit_hitcount,
|
||||
synlimit_burst: listener.synlimit_burst,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -566,6 +599,7 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
|
||||
cfg.access.user_max_unique_ips_global_each = new.access.user_max_unique_ips_global_each;
|
||||
cfg.access.user_max_unique_ips_mode = new.access.user_max_unique_ips_mode;
|
||||
cfg.access.user_max_unique_ips_window_secs = new.access.user_max_unique_ips_window_secs;
|
||||
overlay_listener_synlimit_fields(&mut cfg.server.listeners, &new.server.listeners);
|
||||
|
||||
if cfg.rebuild_runtime_user_auth().is_err() {
|
||||
cfg.runtime_user_auth = None;
|
||||
@@ -574,6 +608,21 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
|
||||
cfg
|
||||
}
|
||||
|
||||
fn overlay_listener_synlimit_fields(old: &mut [ListenerConfig], new: &[ListenerConfig]) {
|
||||
if old.len() != new.len() {
|
||||
return;
|
||||
}
|
||||
for (old_listener, new_listener) in old.iter_mut().zip(new.iter()) {
|
||||
if old_listener.ip != new_listener.ip || old_listener.port != new_listener.port {
|
||||
continue;
|
||||
}
|
||||
old_listener.synlimit = new_listener.synlimit;
|
||||
old_listener.synlimit_seconds = new_listener.synlimit_seconds;
|
||||
old_listener.synlimit_hitcount = new_listener.synlimit_hitcount;
|
||||
old_listener.synlimit_burst = new_listener.synlimit_burst;
|
||||
}
|
||||
}
|
||||
|
||||
/// Warn if any non-hot fields changed (require restart).
|
||||
fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: bool) {
|
||||
let mut warned = false;
|
||||
@@ -850,6 +899,13 @@ fn log_changes(
|
||||
);
|
||||
}
|
||||
|
||||
if old_hot.listener_synlimit != new_hot.listener_synlimit {
|
||||
info!(
|
||||
"config reload: server.listeners SYN limiter updated ({} listeners)",
|
||||
new_hot.listener_synlimit.len()
|
||||
);
|
||||
}
|
||||
|
||||
if old_hot.desync_all_full != new_hot.desync_all_full {
|
||||
info!(
|
||||
"config reload: desync_all_full: {} → {}",
|
||||
|
||||
@@ -346,6 +346,10 @@ const LISTENER_CONFIG_KEYS: &[&str] = &[
|
||||
"ip",
|
||||
"port",
|
||||
"client_mss",
|
||||
"synlimit",
|
||||
"synlimit_seconds",
|
||||
"synlimit_hitcount",
|
||||
"synlimit_burst",
|
||||
"announce",
|
||||
"announce_ip",
|
||||
"proxy_protocol",
|
||||
@@ -1948,6 +1952,21 @@ impl ProxyConfig {
|
||||
ProxyError::Config(format!("server.listeners[{idx}].client_mss {error}"))
|
||||
})?;
|
||||
}
|
||||
if listener.synlimit_seconds == 0 {
|
||||
return Err(ProxyError::Config(format!(
|
||||
"server.listeners[{idx}].synlimit_seconds must be > 0"
|
||||
)));
|
||||
}
|
||||
if listener.synlimit_hitcount == 0 {
|
||||
return Err(ProxyError::Config(format!(
|
||||
"server.listeners[{idx}].synlimit_hitcount must be > 0"
|
||||
)));
|
||||
}
|
||||
if listener.synlimit_burst == 0 {
|
||||
return Err(ProxyError::Config(format!(
|
||||
"server.listeners[{idx}].synlimit_burst must be > 0"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
if config.server.accept_permit_timeout_ms > 60_000 {
|
||||
@@ -2186,6 +2205,10 @@ impl ProxyConfig {
|
||||
ip: ipv4,
|
||||
port: Some(config.server.port),
|
||||
client_mss: None,
|
||||
synlimit: SynLimitMode::default(),
|
||||
synlimit_seconds: default_synlimit_seconds(),
|
||||
synlimit_hitcount: default_synlimit_hitcount(),
|
||||
synlimit_burst: default_synlimit_burst(),
|
||||
announce: None,
|
||||
announce_ip: None,
|
||||
proxy_protocol: None,
|
||||
@@ -2199,6 +2222,10 @@ impl ProxyConfig {
|
||||
ip: ipv6,
|
||||
port: Some(config.server.port),
|
||||
client_mss: None,
|
||||
synlimit: SynLimitMode::default(),
|
||||
synlimit_seconds: default_synlimit_seconds(),
|
||||
synlimit_hitcount: default_synlimit_hitcount(),
|
||||
synlimit_burst: default_synlimit_burst(),
|
||||
announce: None,
|
||||
announce_ip: None,
|
||||
proxy_protocol: None,
|
||||
|
||||
@@ -429,7 +429,7 @@ pub struct GeneralConfig {
|
||||
pub ad_tag: Option<String>,
|
||||
|
||||
/// Public IP override for middle-proxy NAT environments.
|
||||
/// When set, this IP is used in ME key derivation and RPC_PROXY_REQ "our_addr".
|
||||
/// When set, this IP is used in ME key derivation and local address translation.
|
||||
#[serde(default)]
|
||||
pub middle_proxy_nat_ip: Option<IpAddr>,
|
||||
|
||||
@@ -1369,6 +1369,77 @@ impl ConntrackPressureProfile {
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-listener SYN limiter mode.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum SynLimitMode {
|
||||
/// Disable SYN limiting for this listener.
|
||||
#[default]
|
||||
Off,
|
||||
/// Use iptables/ip6tables filter rules with the hashlimit match.
|
||||
Iptables,
|
||||
/// Use nftables rules with per-source token-bucket meters.
|
||||
Nftables,
|
||||
}
|
||||
|
||||
impl Serialize for SynLimitMode {
|
||||
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
Self::Off => serializer.serialize_bool(false),
|
||||
Self::Iptables => serializer.serialize_str("iptables"),
|
||||
Self::Nftables => serializer.serialize_str("nftables"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for SynLimitMode {
|
||||
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct SynLimitModeVisitor;
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for SynLimitModeVisitor {
|
||||
type Value = SynLimitMode;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
formatter.write_str("false, iptables, or nftables")
|
||||
}
|
||||
|
||||
fn visit_bool<E>(self, value: bool) -> std::result::Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
if value {
|
||||
Err(E::custom(
|
||||
"synlimit=true is ambiguous; use \"iptables\" or \"nftables\"",
|
||||
))
|
||||
} else {
|
||||
Ok(SynLimitMode::Off)
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
match value.trim().to_ascii_lowercase().as_str() {
|
||||
"false" | "off" | "disabled" | "none" => Ok(SynLimitMode::Off),
|
||||
"iptables" => Ok(SynLimitMode::Iptables),
|
||||
"nftables" => Ok(SynLimitMode::Nftables),
|
||||
_ => Err(E::custom(
|
||||
"synlimit must be false, \"iptables\", or \"nftables\"",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(SynLimitModeVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ConntrackControlConfig {
|
||||
/// Enables runtime conntrack-control worker for pressure mitigation.
|
||||
@@ -2102,6 +2173,18 @@ pub struct ListenerConfig {
|
||||
/// Empty string disables MSS shaping for this listener.
|
||||
#[serde(default)]
|
||||
pub client_mss: Option<String>,
|
||||
/// Per-listener SYN limiter mode.
|
||||
#[serde(default)]
|
||||
pub synlimit: SynLimitMode,
|
||||
/// Token-bucket rate interval for the per-listener SYN limiter.
|
||||
#[serde(default = "default_synlimit_seconds")]
|
||||
pub synlimit_seconds: u32,
|
||||
/// Token-bucket rate amount for the per-listener SYN limiter.
|
||||
#[serde(default = "default_synlimit_hitcount")]
|
||||
pub synlimit_hitcount: u32,
|
||||
/// Token-bucket burst size for the per-listener SYN limiter.
|
||||
#[serde(default = "default_synlimit_burst")]
|
||||
pub synlimit_burst: u32,
|
||||
/// IP address or hostname to announce in proxy links.
|
||||
/// Takes precedence over `announce_ip` if both are set.
|
||||
#[serde(default)]
|
||||
|
||||
@@ -208,6 +208,8 @@ pub(crate) async fn initialize_me_pool(
|
||||
me_nat_probe,
|
||||
None,
|
||||
config.network.stun_servers.clone(),
|
||||
config.network.stun_tcp_fallback,
|
||||
config.network.http_ip_detect_urls.clone(),
|
||||
config.general.stun_nat_probe_concurrency,
|
||||
probe.detected_ipv6,
|
||||
config.timeouts.me_one_retry,
|
||||
|
||||
@@ -45,6 +45,7 @@ use crate::stats::beobachten::BeobachtenStore;
|
||||
use crate::stats::telemetry::TelemetryPolicy;
|
||||
use crate::stats::{ReplayChecker, Stats};
|
||||
use crate::stream::BufferPool;
|
||||
use crate::synlimit_control;
|
||||
use crate::transport::UpstreamManager;
|
||||
use crate::transport::middle_proxy::MePool;
|
||||
use helpers::{
|
||||
@@ -909,6 +910,9 @@ async fn run_telemt_core(
|
||||
// On Unix, caller supplies privilege drop after bind (may require root for port < 1024).
|
||||
drop_after_bind();
|
||||
|
||||
synlimit_control::reconcile_synlimit_rules(&config).await;
|
||||
synlimit_control::spawn_synlimit_controller(config_rx.clone());
|
||||
|
||||
runtime_tasks::apply_runtime_log_filter(
|
||||
has_rust_log,
|
||||
&effective_log_level,
|
||||
|
||||
@@ -19,6 +19,7 @@ use tokio::signal::unix::{SignalKind, signal};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::stats::Stats;
|
||||
use crate::synlimit_control;
|
||||
use crate::transport::middle_proxy::MePool;
|
||||
|
||||
use super::helpers::{format_uptime, unit_label};
|
||||
@@ -102,6 +103,10 @@ async fn perform_shutdown(
|
||||
let uptime_secs = process_started_at.elapsed().as_secs();
|
||||
info!("Uptime: {}", format_uptime(uptime_secs));
|
||||
|
||||
if let Err(error) = synlimit_control::clear_synlimit_rules_all_backends().await {
|
||||
warn!(error = %error, "Failed to clear SYN limiter rules during shutdown");
|
||||
}
|
||||
|
||||
// Graceful ME pool shutdown
|
||||
if let Some(pool) = &me_pool {
|
||||
match tokio::time::timeout(Duration::from_secs(2), pool.shutdown_send_close_conn_all())
|
||||
|
||||
@@ -30,6 +30,7 @@ mod service;
|
||||
mod startup;
|
||||
mod stats;
|
||||
mod stream;
|
||||
mod synlimit_control;
|
||||
mod tls_front;
|
||||
mod transport;
|
||||
mod util;
|
||||
|
||||
@@ -12,7 +12,7 @@ use tracing::{debug, info, warn};
|
||||
use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType};
|
||||
use crate::error::Result;
|
||||
use crate::network::stun::{
|
||||
DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind,
|
||||
DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind_and_tcp_fallback,
|
||||
};
|
||||
use crate::transport::UpstreamManager;
|
||||
|
||||
@@ -58,6 +58,7 @@ impl NetworkDecision {
|
||||
}
|
||||
|
||||
const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const STUN_BATCH_TCP_FALLBACK_TIMEOUT: Duration = Duration::from_secs(12);
|
||||
|
||||
pub async fn run_probe(
|
||||
config: &NetworkConfig,
|
||||
@@ -81,8 +82,14 @@ pub async fn run_probe(
|
||||
warn!("STUN probe is enabled but network.stun_servers is empty");
|
||||
DualStunResult::default()
|
||||
} else {
|
||||
probe_stun_servers_parallel(&servers, stun_nat_probe_concurrency.max(1), None, None)
|
||||
.await
|
||||
probe_stun_servers_parallel(
|
||||
&servers,
|
||||
stun_nat_probe_concurrency.max(1),
|
||||
None,
|
||||
None,
|
||||
config.stun_tcp_fallback,
|
||||
)
|
||||
.await
|
||||
}
|
||||
} else if nat_probe {
|
||||
info!("STUN probe is disabled by network.stun_use=false");
|
||||
@@ -163,6 +170,7 @@ pub async fn run_probe(
|
||||
stun_nat_probe_concurrency.max(1),
|
||||
bind_v4,
|
||||
bind_v6,
|
||||
config.stun_tcp_fallback,
|
||||
)
|
||||
.await;
|
||||
if let Some(reflected) = direct_stun_res.v4.map(|r| r.reflected_addr) {
|
||||
@@ -234,7 +242,7 @@ pub async fn run_probe(
|
||||
Ok(probe)
|
||||
}
|
||||
|
||||
async fn detect_public_ipv4_http(urls: &[String]) -> Option<Ipv4Addr> {
|
||||
pub(crate) async fn detect_public_ipv4_http(urls: &[String]) -> Option<Ipv4Addr> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(3))
|
||||
.build()
|
||||
@@ -277,6 +285,7 @@ async fn probe_stun_servers_parallel(
|
||||
concurrency: usize,
|
||||
bind_v4: Option<IpAddr>,
|
||||
bind_v6: Option<IpAddr>,
|
||||
tcp_fallback: bool,
|
||||
) -> DualStunResult {
|
||||
let mut join_set = JoinSet::new();
|
||||
let mut next_idx = 0usize;
|
||||
@@ -288,9 +297,26 @@ async fn probe_stun_servers_parallel(
|
||||
let stun_addr = servers[next_idx].clone();
|
||||
next_idx += 1;
|
||||
join_set.spawn(async move {
|
||||
let res = timeout(STUN_BATCH_TIMEOUT, async {
|
||||
let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?;
|
||||
let v6 = stun_probe_family_with_bind(&stun_addr, IpFamily::V6, bind_v6).await?;
|
||||
let batch_timeout = if tcp_fallback {
|
||||
STUN_BATCH_TCP_FALLBACK_TIMEOUT
|
||||
} else {
|
||||
STUN_BATCH_TIMEOUT
|
||||
};
|
||||
let res = timeout(batch_timeout, async {
|
||||
let v4 = stun_probe_family_with_bind_and_tcp_fallback(
|
||||
&stun_addr,
|
||||
IpFamily::V4,
|
||||
bind_v4,
|
||||
tcp_fallback,
|
||||
)
|
||||
.await?;
|
||||
let v6 = stun_probe_family_with_bind_and_tcp_fallback(
|
||||
&stun_addr,
|
||||
IpFamily::V6,
|
||||
bind_v6,
|
||||
tcp_fallback,
|
||||
)
|
||||
.await?;
|
||||
Ok::<DualStunResult, crate::error::ProxyError>(DualStunResult { v4, v6 })
|
||||
})
|
||||
.await;
|
||||
|
||||
@@ -4,7 +4,8 @@
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use tokio::net::{UdpSocket, lookup_host};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpSocket, UdpSocket, lookup_host};
|
||||
use tokio::time::{Duration, sleep, timeout};
|
||||
|
||||
use crate::crypto::SecureRandom;
|
||||
@@ -36,9 +37,16 @@ pub struct DualStunResult {
|
||||
}
|
||||
|
||||
pub async fn stun_probe_dual(stun_addr: &str) -> Result<DualStunResult> {
|
||||
stun_probe_dual_with_tcp_fallback(stun_addr, false).await
|
||||
}
|
||||
|
||||
pub async fn stun_probe_dual_with_tcp_fallback(
|
||||
stun_addr: &str,
|
||||
tcp_fallback: bool,
|
||||
) -> Result<DualStunResult> {
|
||||
let (v4, v6) = tokio::join!(
|
||||
stun_probe_family(stun_addr, IpFamily::V4),
|
||||
stun_probe_family(stun_addr, IpFamily::V6),
|
||||
stun_probe_family_with_tcp_fallback(stun_addr, IpFamily::V4, tcp_fallback),
|
||||
stun_probe_family_with_tcp_fallback(stun_addr, IpFamily::V6, tcp_fallback),
|
||||
);
|
||||
|
||||
Ok(DualStunResult { v4: v4?, v6: v6? })
|
||||
@@ -48,13 +56,44 @@ pub async fn stun_probe_family(
|
||||
stun_addr: &str,
|
||||
family: IpFamily,
|
||||
) -> Result<Option<StunProbeResult>> {
|
||||
stun_probe_family_with_bind(stun_addr, family, None).await
|
||||
stun_probe_family_with_tcp_fallback(stun_addr, family, false).await
|
||||
}
|
||||
|
||||
pub async fn stun_probe_family_with_tcp_fallback(
|
||||
stun_addr: &str,
|
||||
family: IpFamily,
|
||||
tcp_fallback: bool,
|
||||
) -> Result<Option<StunProbeResult>> {
|
||||
stun_probe_family_with_bind_and_tcp_fallback(stun_addr, family, None, tcp_fallback).await
|
||||
}
|
||||
|
||||
pub async fn stun_probe_family_with_bind(
|
||||
stun_addr: &str,
|
||||
family: IpFamily,
|
||||
bind_ip: Option<IpAddr>,
|
||||
) -> Result<Option<StunProbeResult>> {
|
||||
stun_probe_family_with_bind_and_tcp_fallback(stun_addr, family, bind_ip, false).await
|
||||
}
|
||||
|
||||
pub async fn stun_probe_family_with_bind_and_tcp_fallback(
|
||||
stun_addr: &str,
|
||||
family: IpFamily,
|
||||
bind_ip: Option<IpAddr>,
|
||||
tcp_fallback: bool,
|
||||
) -> Result<Option<StunProbeResult>> {
|
||||
let udp_attempts = if tcp_fallback { 1 } else { 3 };
|
||||
let udp_result = stun_probe_family_udp(stun_addr, family, bind_ip, udp_attempts).await?;
|
||||
if udp_result.is_some() || !tcp_fallback {
|
||||
return Ok(udp_result);
|
||||
}
|
||||
stun_probe_family_tcp(stun_addr, family, bind_ip).await
|
||||
}
|
||||
|
||||
async fn stun_probe_family_udp(
|
||||
stun_addr: &str,
|
||||
family: IpFamily,
|
||||
bind_ip: Option<IpAddr>,
|
||||
max_attempts: u8,
|
||||
) -> Result<Option<StunProbeResult>> {
|
||||
let bind_addr = match (family, bind_ip) {
|
||||
(IpFamily::V4, Some(IpAddr::V4(ip))) => SocketAddr::new(IpAddr::V4(ip), 0),
|
||||
@@ -94,12 +133,7 @@ pub async fn stun_probe_family_with_bind(
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut req = [0u8; 20];
|
||||
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request
|
||||
req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length
|
||||
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie
|
||||
stun_rng().fill(&mut req[8..20]); // transaction ID
|
||||
|
||||
let req = build_binding_request();
|
||||
let mut buf = [0u8; 256];
|
||||
let mut attempt = 0;
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
@@ -115,7 +149,7 @@ pub async fn stun_probe_family_with_bind(
|
||||
Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN recv failed: {e}"))),
|
||||
Err(_) => {
|
||||
attempt += 1;
|
||||
if attempt >= 3 {
|
||||
if attempt >= max_attempts {
|
||||
return Ok(None);
|
||||
}
|
||||
sleep(backoff).await;
|
||||
@@ -128,19 +162,139 @@ pub async fn stun_probe_family_with_bind(
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let magic = 0x2112A442u32.to_be_bytes();
|
||||
let txid = &req[8..20];
|
||||
let mut idx = 20;
|
||||
while idx + 4 <= n {
|
||||
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap());
|
||||
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize;
|
||||
idx += 4;
|
||||
if idx + alen > n {
|
||||
break;
|
||||
}
|
||||
if let Some(reflected_addr) = parse_reflected_addr(&buf[..n], txid) {
|
||||
let local_addr = socket
|
||||
.local_addr()
|
||||
.map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?;
|
||||
return Ok(Some(StunProbeResult {
|
||||
local_addr,
|
||||
reflected_addr,
|
||||
family,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
match atype {
|
||||
0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn stun_probe_family_tcp(
|
||||
stun_addr: &str,
|
||||
family: IpFamily,
|
||||
bind_ip: Option<IpAddr>,
|
||||
) -> Result<Option<StunProbeResult>> {
|
||||
let target_addr = match resolve_stun_addr(stun_addr, family).await? {
|
||||
Some(addr) => addr,
|
||||
None => return Ok(None),
|
||||
};
|
||||
let socket = match family {
|
||||
IpFamily::V4 => TcpSocket::new_v4(),
|
||||
IpFamily::V6 => TcpSocket::new_v6(),
|
||||
}
|
||||
.map_err(|e| ProxyError::Proxy(format!("STUN TCP socket failed: {e}")))?;
|
||||
match (family, bind_ip) {
|
||||
(IpFamily::V4, Some(IpAddr::V4(ip))) => {
|
||||
if socket.bind(SocketAddr::new(IpAddr::V4(ip), 0)).is_err() {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
(IpFamily::V6, Some(IpAddr::V6(ip))) => {
|
||||
if socket.bind(SocketAddr::new(IpAddr::V6(ip), 0)).is_err() {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
(IpFamily::V4, Some(IpAddr::V6(_))) | (IpFamily::V6, Some(IpAddr::V4(_))) => {
|
||||
return Ok(None);
|
||||
}
|
||||
(_, None) => {}
|
||||
}
|
||||
|
||||
let connect_res = timeout(Duration::from_secs(3), socket.connect(target_addr)).await;
|
||||
let mut stream = match connect_res {
|
||||
Ok(Ok(stream)) => stream,
|
||||
Ok(Err(e))
|
||||
if family == IpFamily::V6
|
||||
&& matches!(
|
||||
e.kind(),
|
||||
std::io::ErrorKind::NetworkUnreachable
|
||||
| std::io::ErrorKind::HostUnreachable
|
||||
| std::io::ErrorKind::Unsupported
|
||||
| std::io::ErrorKind::NetworkDown
|
||||
) =>
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN TCP connect failed: {e}"))),
|
||||
Err(_) => return Ok(None),
|
||||
};
|
||||
|
||||
let req = build_binding_request();
|
||||
timeout(Duration::from_secs(3), stream.write_all(&req))
|
||||
.await
|
||||
.map_err(|_| ProxyError::Proxy("STUN TCP send timeout".to_string()))?
|
||||
.map_err(|e| ProxyError::Proxy(format!("STUN TCP send failed: {e}")))?;
|
||||
|
||||
let mut header = [0u8; 20];
|
||||
timeout(Duration::from_secs(3), stream.read_exact(&mut header))
|
||||
.await
|
||||
.map_err(|_| ProxyError::Proxy("STUN TCP header timeout".to_string()))?
|
||||
.map_err(|e| ProxyError::Proxy(format!("STUN TCP header read failed: {e}")))?;
|
||||
let body_len = u16::from_be_bytes([header[2], header[3]]) as usize;
|
||||
if body_len > 236 {
|
||||
return Ok(None);
|
||||
}
|
||||
let mut buf = [0u8; 256];
|
||||
buf[..20].copy_from_slice(&header);
|
||||
if body_len > 0 {
|
||||
timeout(
|
||||
Duration::from_secs(3),
|
||||
stream.read_exact(&mut buf[20..20 + body_len]),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| ProxyError::Proxy("STUN TCP body timeout".to_string()))?
|
||||
.map_err(|e| ProxyError::Proxy(format!("STUN TCP body read failed: {e}")))?;
|
||||
}
|
||||
|
||||
let txid = &req[8..20];
|
||||
let Some(reflected_addr) = parse_reflected_addr(&buf[..20 + body_len], txid) else {
|
||||
return Ok(None);
|
||||
};
|
||||
let local_addr = stream
|
||||
.local_addr()
|
||||
.map_err(|e| ProxyError::Proxy(format!("STUN TCP local_addr failed: {e}")))?;
|
||||
Ok(Some(StunProbeResult {
|
||||
local_addr,
|
||||
reflected_addr,
|
||||
family,
|
||||
}))
|
||||
}
|
||||
|
||||
fn build_binding_request() -> [u8; 20] {
|
||||
let mut req = [0u8; 20];
|
||||
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes());
|
||||
req[2..4].copy_from_slice(&0u16.to_be_bytes());
|
||||
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes());
|
||||
stun_rng().fill(&mut req[8..20]);
|
||||
req
|
||||
}
|
||||
|
||||
fn parse_reflected_addr(buf: &[u8], txid: &[u8]) -> Option<SocketAddr> {
|
||||
if buf.len() < 20 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let magic = 0x2112A442u32.to_be_bytes();
|
||||
let mut idx = 20;
|
||||
while idx + 4 <= buf.len() {
|
||||
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().ok()?);
|
||||
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().ok()?) as usize;
|
||||
idx += 4;
|
||||
if idx + alen > buf.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
match atype {
|
||||
0x0020 | 0x0001 => {
|
||||
if alen < 8 {
|
||||
break;
|
||||
}
|
||||
@@ -157,7 +311,6 @@ pub async fn stun_probe_family_with_bind(
|
||||
|
||||
let raw_ip = &buf[idx + 4..idx + 4 + len_check];
|
||||
let mut port = u16::from_be_bytes(port_bytes);
|
||||
|
||||
let reflected_ip = if atype == 0x0020 {
|
||||
port ^= ((magic[0] as u16) << 8) | magic[1] as u16;
|
||||
match family_byte {
|
||||
@@ -172,7 +325,9 @@ pub async fn stun_probe_family_with_bind(
|
||||
}
|
||||
0x02 => {
|
||||
let mut ip = [0u8; 16];
|
||||
let xor_key = [magic.as_slice(), txid].concat();
|
||||
let mut xor_key = [0u8; 16];
|
||||
xor_key[..4].copy_from_slice(&magic);
|
||||
xor_key[4..].copy_from_slice(txid.get(..12)?);
|
||||
for (i, b) in raw_ip.iter().enumerate().take(16) {
|
||||
ip[i] = *b ^ xor_key[i];
|
||||
}
|
||||
@@ -185,34 +340,24 @@ pub async fn stun_probe_family_with_bind(
|
||||
}
|
||||
} else {
|
||||
match family_byte {
|
||||
0x01 => IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3])),
|
||||
0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).unwrap())),
|
||||
0x01 => {
|
||||
IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3]))
|
||||
}
|
||||
0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).ok()?)),
|
||||
_ => {
|
||||
idx += (alen + 3) & !3;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let reflected_addr = SocketAddr::new(reflected_ip, port);
|
||||
let local_addr = socket
|
||||
.local_addr()
|
||||
.map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?;
|
||||
|
||||
return Ok(Some(StunProbeResult {
|
||||
local_addr,
|
||||
reflected_addr,
|
||||
family,
|
||||
}));
|
||||
return Some(SocketAddr::new(reflected_ip, port));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
idx += (alen + 3) & !3;
|
||||
}
|
||||
idx += (alen + 3) & !3;
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
None
|
||||
}
|
||||
|
||||
async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<SocketAddr>> {
|
||||
@@ -245,3 +390,58 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<S
|
||||
});
|
||||
Ok(target)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_reflected_addr_reads_mapped_ipv4() {
|
||||
let txid = [0u8; 12];
|
||||
let mut response = [0u8; 32];
|
||||
response[0..2].copy_from_slice(&0x0101u16.to_be_bytes());
|
||||
response[2..4].copy_from_slice(&12u16.to_be_bytes());
|
||||
response[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes());
|
||||
response[20..22].copy_from_slice(&0x0001u16.to_be_bytes());
|
||||
response[22..24].copy_from_slice(&8u16.to_be_bytes());
|
||||
response[25] = 0x01;
|
||||
response[26..28].copy_from_slice(&443u16.to_be_bytes());
|
||||
response[28..32].copy_from_slice(&[203, 0, 113, 9]);
|
||||
|
||||
let reflected = parse_reflected_addr(&response, &txid).unwrap();
|
||||
assert_eq!(
|
||||
reflected,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 443)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_reflected_addr_reads_xor_mapped_ipv4() {
|
||||
let txid = [0u8; 12];
|
||||
let magic = 0x2112A442u32.to_be_bytes();
|
||||
let port = 443u16;
|
||||
let ip = [203u8, 0, 113, 9];
|
||||
let xport = port ^ (((magic[0] as u16) << 8) | magic[1] as u16);
|
||||
let xip = [
|
||||
ip[0] ^ magic[0],
|
||||
ip[1] ^ magic[1],
|
||||
ip[2] ^ magic[2],
|
||||
ip[3] ^ magic[3],
|
||||
];
|
||||
let mut response = [0u8; 32];
|
||||
response[0..2].copy_from_slice(&0x0101u16.to_be_bytes());
|
||||
response[2..4].copy_from_slice(&12u16.to_be_bytes());
|
||||
response[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes());
|
||||
response[20..22].copy_from_slice(&0x0020u16.to_be_bytes());
|
||||
response[22..24].copy_from_slice(&8u16.to_be_bytes());
|
||||
response[25] = 0x01;
|
||||
response[26..28].copy_from_slice(&xport.to_be_bytes());
|
||||
response[28..32].copy_from_slice(&xip);
|
||||
|
||||
let reflected = parse_reflected_addr(&response, &txid).unwrap();
|
||||
assert_eq!(
|
||||
reflected,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 443)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,9 @@
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::protocol::framing::{
|
||||
secure_version_d_body_len_from_wire_len, secure_version_d_padding_len,
|
||||
};
|
||||
use std::sync::LazyLock;
|
||||
|
||||
// ============= Telegram Datacenters =============
|
||||
@@ -236,22 +239,20 @@ pub fn is_valid_secure_payload_len(data_len: usize) -> bool {
|
||||
}
|
||||
|
||||
/// Compute Secure Intermediate payload length from wire length.
|
||||
/// Secure mode strips up to 3 random tail bytes by truncating to 4-byte boundary.
|
||||
/// Secure mode cannot distinguish full-word padding from payload, so only the
|
||||
/// non-aligned tail bytes are stripped.
|
||||
pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option<usize> {
|
||||
if wire_len < 4 {
|
||||
return None;
|
||||
}
|
||||
Some(wire_len - (wire_len % 4))
|
||||
secure_version_d_body_len_from_wire_len(wire_len)
|
||||
}
|
||||
|
||||
/// Generate padding length for Secure Intermediate protocol.
|
||||
/// Data must be 4-byte aligned; padding is 1..=3 so total is never divisible by 4.
|
||||
/// Telegram Desktop uses a 4-bit random padding length for VersionD packets.
|
||||
pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize {
|
||||
debug_assert!(
|
||||
is_valid_secure_payload_len(data_len),
|
||||
"Secure payload must be 4-byte aligned, got {data_len}"
|
||||
);
|
||||
rng.range(3) + 1
|
||||
secure_version_d_padding_len(rng)
|
||||
}
|
||||
|
||||
// ============= Timeouts =============
|
||||
@@ -424,21 +425,15 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secure_padding_never_produces_aligned_total() {
|
||||
fn secure_padding_matches_tdesktop_range() {
|
||||
let rng = SecureRandom::new();
|
||||
for data_len in (0..1000).step_by(4) {
|
||||
for _ in 0..100 {
|
||||
let padding = secure_padding_len(data_len, &rng);
|
||||
assert!(
|
||||
padding <= 3,
|
||||
padding <= 15,
|
||||
"padding out of range: data_len={data_len}, padding={padding}"
|
||||
);
|
||||
assert_ne!(
|
||||
(data_len + padding) % 4,
|
||||
0,
|
||||
"invariant violated: data_len={data_len}, padding={padding}, total={}",
|
||||
data_len + padding
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -454,6 +449,16 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secure_wire_len_preserves_full_word_tail() {
|
||||
let payload_len = 64;
|
||||
for padding in [4usize, 8, 12] {
|
||||
let wire_len = payload_len + padding;
|
||||
let recovered = secure_payload_len_from_wire_len(wire_len);
|
||||
assert_eq!(recovered, Some(wire_len));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secure_wire_len_rejects_too_short_frames() {
|
||||
assert_eq!(secure_payload_len_from_wire_len(0), None);
|
||||
|
||||
92
src/protocol/framing.rs
Normal file
92
src/protocol/framing.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
//! Shared MTProto transport framing helpers.
|
||||
|
||||
use crate::crypto::SecureRandom;
|
||||
|
||||
/// QuickACK marker bit used by Intermediate and Secure Intermediate headers.
|
||||
pub(crate) const INTERMEDIATE_QUICKACK_FLAG: u32 = 0x8000_0000;
|
||||
|
||||
/// Payload length mask used by Intermediate and Secure Intermediate headers.
|
||||
pub(crate) const INTERMEDIATE_WIRE_LEN_MASK: u32 = 0x7fff_ffff;
|
||||
|
||||
/// Maximum random tail length used by Telegram Desktop VersionD packets.
|
||||
pub(crate) const SECURE_VERSION_D_PADDING_MAX: usize = 15;
|
||||
|
||||
/// Parsed Intermediate/Secure Intermediate length header.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub(crate) struct IntermediateHeader {
|
||||
/// Payload length on the wire, excluding the four-byte header.
|
||||
pub(crate) wire_len: usize,
|
||||
/// Whether the QuickACK marker bit was set in the length header.
|
||||
pub(crate) quickack: bool,
|
||||
}
|
||||
|
||||
/// Parse an Intermediate/Secure Intermediate length header.
|
||||
pub(crate) fn parse_intermediate_header(header: [u8; 4]) -> IntermediateHeader {
|
||||
let raw = u32::from_le_bytes(header);
|
||||
IntermediateHeader {
|
||||
wire_len: (raw & INTERMEDIATE_WIRE_LEN_MASK) as usize,
|
||||
quickack: (raw & INTERMEDIATE_QUICKACK_FLAG) != 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode an Intermediate/Secure Intermediate length header.
|
||||
pub(crate) fn encode_intermediate_header(wire_len: usize, quickack: bool) -> Option<u32> {
|
||||
if wire_len > INTERMEDIATE_WIRE_LEN_MASK as usize {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut raw = u32::try_from(wire_len).ok()?;
|
||||
if quickack {
|
||||
raw |= INTERMEDIATE_QUICKACK_FLAG;
|
||||
}
|
||||
Some(raw)
|
||||
}
|
||||
|
||||
/// Recover the VersionD body length visible to MTProto from the encrypted wire length.
|
||||
pub(crate) fn secure_version_d_body_len_from_wire_len(wire_len: usize) -> Option<usize> {
|
||||
if wire_len < 4 {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(wire_len - (wire_len % 4))
|
||||
}
|
||||
|
||||
/// Generate Telegram Desktop-compatible VersionD random tail length.
|
||||
pub(crate) fn secure_version_d_padding_len(rng: &SecureRandom) -> usize {
|
||||
rng.range(SECURE_VERSION_D_PADDING_MAX + 1)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn intermediate_header_roundtrip_preserves_quickack_zero_length() {
|
||||
let encoded = encode_intermediate_header(0, true).unwrap();
|
||||
assert_eq!(encoded, INTERMEDIATE_QUICKACK_FLAG);
|
||||
|
||||
let parsed = parse_intermediate_header(encoded.to_le_bytes());
|
||||
assert_eq!(parsed.wire_len, 0);
|
||||
assert!(parsed.quickack);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn intermediate_header_rejects_lengths_above_31_bits() {
|
||||
assert_eq!(
|
||||
encode_intermediate_header(INTERMEDIATE_WIRE_LEN_MASK as usize, false),
|
||||
Some(INTERMEDIATE_WIRE_LEN_MASK)
|
||||
);
|
||||
assert_eq!(
|
||||
encode_intermediate_header(INTERMEDIATE_WIRE_LEN_MASK as usize + 1, false),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secure_version_d_body_len_strips_only_non_word_tail() {
|
||||
assert_eq!(secure_version_d_body_len_from_wire_len(3), None);
|
||||
assert_eq!(secure_version_d_body_len_from_wire_len(8), Some(8));
|
||||
assert_eq!(secure_version_d_body_len_from_wire_len(11), Some(8));
|
||||
assert_eq!(secure_version_d_body_len_from_wire_len(12), Some(12));
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
pub mod constants;
|
||||
pub mod frame;
|
||||
pub(crate) mod framing;
|
||||
pub mod obfuscation;
|
||||
pub mod tls;
|
||||
pub mod tls_fingerprint;
|
||||
|
||||
@@ -638,15 +638,21 @@ fn build_server_hello_key_share_for_group(
|
||||
group: u16,
|
||||
rng: &SecureRandom,
|
||||
) -> Option<ServerHelloKeyShare> {
|
||||
let expected_key_exchange_len = client_hello_key_share_group_len(group)?;
|
||||
client_hello_key_share_group_entry(handshake, group, expected_key_exchange_len)?;
|
||||
|
||||
// FakeTLS clients validate ServerHello shape and digest, not TLS traffic
|
||||
// secrets, so the response must mirror the offered group without binding to
|
||||
// the camouflage key bytes embedded in ClientHello.
|
||||
match group {
|
||||
TLS_NAMED_GROUP_X25519MLKEM768 => {
|
||||
let key_exchange = build_x25519mlkem768_server_key_share(handshake, rng)?;
|
||||
Some(ServerHelloKeyShare::new(group, key_exchange))
|
||||
}
|
||||
TLS_NAMED_GROUP_X25519 => {
|
||||
let key_exchange = build_x25519_server_key_share(handshake, rng)?;
|
||||
Some(ServerHelloKeyShare::new(group, key_exchange))
|
||||
}
|
||||
TLS_NAMED_GROUP_X25519MLKEM768 => Some(ServerHelloKeyShare::new(
|
||||
group,
|
||||
gen_fake_x25519mlkem768_server_key_share(rng),
|
||||
)),
|
||||
TLS_NAMED_GROUP_X25519 => Some(ServerHelloKeyShare::new(
|
||||
group,
|
||||
gen_fake_x25519_key(rng).to_vec(),
|
||||
)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
|
||||
use dashmap::DashMap;
|
||||
use dashmap::mapref::entry::Entry;
|
||||
use hmac::{Hmac, Mac};
|
||||
#[cfg(test)]
|
||||
use std::collections::HashSet;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
@@ -33,8 +32,10 @@ use crate::stream::{CryptoReader, CryptoWriter, FakeTlsReader, FakeTlsWriter};
|
||||
use crate::tls_front::{TlsFrontCache, emulator};
|
||||
#[cfg(test)]
|
||||
use rand::RngExt;
|
||||
use sha2::Sha256;
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
mod tls_auth;
|
||||
|
||||
use self::tls_auth::{parse_tls_auth_material, validate_tls_secret_candidate};
|
||||
|
||||
const ACCESS_SECRET_BYTES: usize = 16;
|
||||
const UNKNOWN_SNI_WARN_COOLDOWN_SECS: u64 = 5;
|
||||
@@ -58,8 +59,6 @@ const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8;
|
||||
const EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD: usize = 64;
|
||||
const RECENT_USER_RING_SCAN_LIMIT: usize = 32;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
#[cfg(test)]
|
||||
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
|
||||
#[cfg(not(test))]
|
||||
@@ -104,23 +103,6 @@ fn should_emit_unknown_sni_warn_in(shared: &ProxySharedState, now: Instant) -> b
|
||||
true
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct ParsedTlsAuthMaterial {
|
||||
digest: [u8; tls::TLS_DIGEST_LEN],
|
||||
session_id: [u8; 32],
|
||||
session_id_len: usize,
|
||||
now: i64,
|
||||
ignore_time_skew: bool,
|
||||
boot_time_cap_secs: u32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct TlsCandidateValidation {
|
||||
digest: [u8; tls::TLS_DIGEST_LEN],
|
||||
session_id: [u8; 32],
|
||||
session_id_len: usize,
|
||||
}
|
||||
|
||||
struct MtprotoCandidateValidation {
|
||||
proto_tag: ProtoTag,
|
||||
dc_idx: i16,
|
||||
@@ -251,104 +233,6 @@ fn budget_for_validation(total_users: usize, overload: bool, has_hint: bool) ->
|
||||
total_users.min(cap.max(1))
|
||||
}
|
||||
|
||||
fn parse_tls_auth_material(
|
||||
handshake: &[u8],
|
||||
ignore_time_skew: bool,
|
||||
replay_window_secs: u64,
|
||||
) -> Option<ParsedTlsAuthMaterial> {
|
||||
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let digest: [u8; tls::TLS_DIGEST_LEN] = handshake
|
||||
[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
|
||||
.try_into()
|
||||
.ok()?;
|
||||
|
||||
let session_id_len_pos = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN;
|
||||
let session_id_len = usize::from(handshake.get(session_id_len_pos).copied()?);
|
||||
if session_id_len > 32 {
|
||||
return None;
|
||||
}
|
||||
let session_id_start = session_id_len_pos + 1;
|
||||
if handshake.len() < session_id_start + session_id_len {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut session_id = [0u8; 32];
|
||||
session_id[..session_id_len]
|
||||
.copy_from_slice(&handshake[session_id_start..session_id_start + session_id_len]);
|
||||
|
||||
let now = if !ignore_time_skew {
|
||||
let d = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.ok()?;
|
||||
i64::try_from(d.as_secs()).ok()?
|
||||
} else {
|
||||
0_i64
|
||||
};
|
||||
|
||||
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
|
||||
let boot_time_cap_secs = if ignore_time_skew {
|
||||
0
|
||||
} else {
|
||||
tls::BOOT_TIME_MAX_SECS
|
||||
.min(replay_window_u32)
|
||||
.min(tls::BOOT_TIME_COMPAT_MAX_SECS)
|
||||
};
|
||||
|
||||
Some(ParsedTlsAuthMaterial {
|
||||
digest,
|
||||
session_id,
|
||||
session_id_len,
|
||||
now,
|
||||
ignore_time_skew,
|
||||
boot_time_cap_secs,
|
||||
})
|
||||
}
|
||||
|
||||
fn compute_tls_hmac_zeroed_digest(secret: &[u8], handshake: &[u8]) -> [u8; 32] {
|
||||
let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
|
||||
mac.update(&handshake[..tls::TLS_DIGEST_POS]);
|
||||
mac.update(&[0u8; tls::TLS_DIGEST_LEN]);
|
||||
mac.update(&handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN..]);
|
||||
mac.finalize().into_bytes().into()
|
||||
}
|
||||
|
||||
fn validate_tls_secret_candidate(
|
||||
parsed: &ParsedTlsAuthMaterial,
|
||||
handshake: &[u8],
|
||||
secret: &[u8],
|
||||
) -> Option<TlsCandidateValidation> {
|
||||
let computed = compute_tls_hmac_zeroed_digest(secret, handshake);
|
||||
if !bool::from(parsed.digest[..28].ct_eq(&computed[..28])) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let timestamp = u32::from_le_bytes([
|
||||
parsed.digest[28] ^ computed[28],
|
||||
parsed.digest[29] ^ computed[29],
|
||||
parsed.digest[30] ^ computed[30],
|
||||
parsed.digest[31] ^ computed[31],
|
||||
]);
|
||||
|
||||
if !parsed.ignore_time_skew {
|
||||
let is_boot_time = parsed.boot_time_cap_secs > 0 && timestamp < parsed.boot_time_cap_secs;
|
||||
if !is_boot_time {
|
||||
let time_diff = parsed.now - i64::from(timestamp);
|
||||
if !(tls::TIME_SKEW_MIN..=tls::TIME_SKEW_MAX).contains(&time_diff) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(TlsCandidateValidation {
|
||||
digest: parsed.digest,
|
||||
session_id: parsed.session_id,
|
||||
session_id_len: parsed.session_id_len,
|
||||
})
|
||||
}
|
||||
|
||||
fn validate_mtproto_secret_candidate(
|
||||
handshake: &[u8; HANDSHAKE_LEN],
|
||||
dec_prekey: &[u8; PREKEY_LEN],
|
||||
@@ -1857,7 +1741,16 @@ where
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
||||
let validation = matched_validation.expect("validation must exist when matched");
|
||||
let Some(validation) = matched_validation else {
|
||||
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
warn!(
|
||||
peer = %peer,
|
||||
user = %matched_user,
|
||||
"MTProto handshake matched user without validation material"
|
||||
);
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
};
|
||||
|
||||
if config
|
||||
.access
|
||||
|
||||
126
src/proxy/handshake/tls_auth.rs
Normal file
126
src/proxy/handshake/tls_auth.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
use crate::protocol::tls;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
/// Parsed TLS authentication material extracted from a ClientHello candidate.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(super) struct ParsedTlsAuthMaterial {
|
||||
digest: [u8; tls::TLS_DIGEST_LEN],
|
||||
session_id: [u8; 32],
|
||||
session_id_len: usize,
|
||||
now: i64,
|
||||
ignore_time_skew: bool,
|
||||
boot_time_cap_secs: u32,
|
||||
}
|
||||
|
||||
/// Successful TLS secret validation output used by the handshake state machine.
|
||||
#[derive(Clone, Copy)]
|
||||
pub(super) struct TlsCandidateValidation {
|
||||
pub(super) digest: [u8; tls::TLS_DIGEST_LEN],
|
||||
pub(super) session_id: [u8; 32],
|
||||
pub(super) session_id_len: usize,
|
||||
}
|
||||
|
||||
/// Parse TLS auth digest and session-id material from a candidate handshake.
|
||||
pub(super) fn parse_tls_auth_material(
|
||||
handshake: &[u8],
|
||||
ignore_time_skew: bool,
|
||||
replay_window_secs: u64,
|
||||
) -> Option<ParsedTlsAuthMaterial> {
|
||||
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let digest: [u8; tls::TLS_DIGEST_LEN] = handshake
|
||||
[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
|
||||
.try_into()
|
||||
.ok()?;
|
||||
|
||||
let session_id_len_pos = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN;
|
||||
let session_id_len = usize::from(handshake.get(session_id_len_pos).copied()?);
|
||||
if session_id_len > 32 {
|
||||
return None;
|
||||
}
|
||||
let session_id_start = session_id_len_pos + 1;
|
||||
if handshake.len() < session_id_start + session_id_len {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut session_id = [0u8; 32];
|
||||
session_id[..session_id_len]
|
||||
.copy_from_slice(&handshake[session_id_start..session_id_start + session_id_len]);
|
||||
|
||||
let now = if !ignore_time_skew {
|
||||
let d = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.ok()?;
|
||||
i64::try_from(d.as_secs()).ok()?
|
||||
} else {
|
||||
0_i64
|
||||
};
|
||||
|
||||
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
|
||||
let boot_time_cap_secs = if ignore_time_skew {
|
||||
0
|
||||
} else {
|
||||
tls::BOOT_TIME_MAX_SECS
|
||||
.min(replay_window_u32)
|
||||
.min(tls::BOOT_TIME_COMPAT_MAX_SECS)
|
||||
};
|
||||
|
||||
Some(ParsedTlsAuthMaterial {
|
||||
digest,
|
||||
session_id,
|
||||
session_id_len,
|
||||
now,
|
||||
ignore_time_skew,
|
||||
boot_time_cap_secs,
|
||||
})
|
||||
}
|
||||
|
||||
fn compute_tls_hmac_zeroed_digest(secret: &[u8], handshake: &[u8]) -> Option<[u8; 32]> {
|
||||
let mut mac = HmacSha256::new_from_slice(secret).ok()?;
|
||||
mac.update(&handshake[..tls::TLS_DIGEST_POS]);
|
||||
mac.update(&[0u8; tls::TLS_DIGEST_LEN]);
|
||||
mac.update(&handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN..]);
|
||||
Some(mac.finalize().into_bytes().into())
|
||||
}
|
||||
|
||||
/// Validate a candidate secret against parsed TLS authentication material.
|
||||
pub(super) fn validate_tls_secret_candidate(
|
||||
parsed: &ParsedTlsAuthMaterial,
|
||||
handshake: &[u8],
|
||||
secret: &[u8],
|
||||
) -> Option<TlsCandidateValidation> {
|
||||
let computed = compute_tls_hmac_zeroed_digest(secret, handshake)?;
|
||||
if !bool::from(parsed.digest[..28].ct_eq(&computed[..28])) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let timestamp = u32::from_le_bytes([
|
||||
parsed.digest[28] ^ computed[28],
|
||||
parsed.digest[29] ^ computed[29],
|
||||
parsed.digest[30] ^ computed[30],
|
||||
parsed.digest[31] ^ computed[31],
|
||||
]);
|
||||
|
||||
if !parsed.ignore_time_skew {
|
||||
let is_boot_time = parsed.boot_time_cap_secs > 0 && timestamp < parsed.boot_time_cap_secs;
|
||||
if !is_boot_time {
|
||||
let time_diff = parsed.now - i64::from(timestamp);
|
||||
if !(tls::TIME_SKEW_MIN..=tls::TIME_SKEW_MAX).contains(&time_diff) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(TlsCandidateValidation {
|
||||
digest: parsed.digest,
|
||||
session_id: parsed.session_id,
|
||||
session_id_len: parsed.session_id_len,
|
||||
})
|
||||
}
|
||||
@@ -5,6 +5,7 @@ use crate::network::dns_overrides::resolve_socket_addr;
|
||||
use crate::protocol::tls;
|
||||
use crate::stats::beobachten::BeobachtenStore;
|
||||
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
|
||||
use crate::transport::socket::configure_tcp_socket;
|
||||
#[cfg(unix)]
|
||||
use nix::ifaddrs::getifaddrs;
|
||||
use rand::rngs::StdRng;
|
||||
@@ -36,6 +37,8 @@ const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200);
|
||||
#[cfg(test)]
|
||||
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
|
||||
const MASK_BUFFER_SIZE: usize = 8192;
|
||||
const MASK_BUFFER_GROW_AFTER_BYTES: usize = 256 * 1024;
|
||||
const MASK_BUFFER_MAX_SIZE: usize = 64 * 1024;
|
||||
#[cfg(unix)]
|
||||
#[cfg(not(test))]
|
||||
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300);
|
||||
@@ -53,6 +56,27 @@ struct MaskTcpTarget<'a> {
|
||||
port: u16,
|
||||
}
|
||||
|
||||
fn mask_copy_read_len(total: usize, byte_cap: usize) -> usize {
|
||||
// Keep short scanner probes on the small baseline buffer and grow only
|
||||
// after the session has proven to be sustained masking relay traffic.
|
||||
let active_buffer_size = if total >= MASK_BUFFER_GROW_AFTER_BYTES {
|
||||
MASK_BUFFER_MAX_SIZE
|
||||
} else {
|
||||
MASK_BUFFER_SIZE
|
||||
};
|
||||
|
||||
if byte_cap == 0 {
|
||||
return active_buffer_size;
|
||||
}
|
||||
|
||||
let remaining_budget = byte_cap.saturating_sub(total);
|
||||
if remaining_budget == 0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
remaining_budget.min(active_buffer_size)
|
||||
}
|
||||
|
||||
async fn copy_with_idle_timeout<R, W>(
|
||||
reader: &mut R,
|
||||
writer: &mut W,
|
||||
@@ -64,21 +88,18 @@ where
|
||||
R: AsyncRead + Unpin,
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
|
||||
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||
let mut total = 0usize;
|
||||
let mut ended_by_eof = false;
|
||||
let unlimited = byte_cap == 0;
|
||||
|
||||
loop {
|
||||
let read_len = if unlimited {
|
||||
MASK_BUFFER_SIZE
|
||||
} else {
|
||||
let remaining_budget = byte_cap.saturating_sub(total);
|
||||
if remaining_budget == 0 {
|
||||
break;
|
||||
}
|
||||
remaining_budget.min(MASK_BUFFER_SIZE)
|
||||
};
|
||||
let read_len = mask_copy_read_len(total, byte_cap);
|
||||
if read_len == 0 {
|
||||
break;
|
||||
}
|
||||
if buf.len() < read_len {
|
||||
buf.resize(read_len, 0);
|
||||
}
|
||||
let read_res = timeout(idle_timeout, reader.read(&mut buf[..read_len])).await;
|
||||
let n = match read_res {
|
||||
Ok(Ok(n)) => n,
|
||||
@@ -877,6 +898,12 @@ fn build_mask_proxy_header(
|
||||
}
|
||||
}
|
||||
|
||||
fn configure_mask_backend_socket(stream: &TcpStream) {
|
||||
if let Err(e) = configure_tcp_socket(stream, false, Duration::from_secs(0)) {
|
||||
debug!(error = %e, "Failed to configure mask backend socket");
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a bad client by forwarding to mask host
|
||||
pub async fn handle_bad_client<R, W>(
|
||||
reader: R,
|
||||
@@ -1047,6 +1074,7 @@ pub async fn handle_bad_client<R, W>(
|
||||
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
|
||||
match connect_result {
|
||||
Ok(Ok(stream)) => {
|
||||
configure_mask_backend_socket(&stream);
|
||||
let proxy_header =
|
||||
build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr);
|
||||
|
||||
@@ -1190,20 +1218,17 @@ async fn consume_client_data<R: AsyncRead + Unpin>(
|
||||
idle_timeout: Duration,
|
||||
) {
|
||||
// Keep drain path fail-closed under slow-loris stalls.
|
||||
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
|
||||
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||
let mut total = 0usize;
|
||||
let unlimited = byte_cap == 0;
|
||||
|
||||
loop {
|
||||
let read_len = if unlimited {
|
||||
MASK_BUFFER_SIZE
|
||||
} else {
|
||||
let remaining_budget = byte_cap.saturating_sub(total);
|
||||
if remaining_budget == 0 {
|
||||
break;
|
||||
}
|
||||
remaining_budget.min(MASK_BUFFER_SIZE)
|
||||
};
|
||||
let read_len = mask_copy_read_len(total, byte_cap);
|
||||
if read_len == 0 {
|
||||
break;
|
||||
}
|
||||
if buf.len() < read_len {
|
||||
buf.resize(read_len, 0);
|
||||
}
|
||||
let n = match timeout(idle_timeout, reader.read(&mut buf[..read_len])).await {
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(_)) | Err(_) => break,
|
||||
@@ -1214,7 +1239,7 @@ async fn consume_client_data<R: AsyncRead + Unpin>(
|
||||
}
|
||||
|
||||
total = total.saturating_add(n);
|
||||
if !unlimited && total >= byte_cap {
|
||||
if byte_cap != 0 && total >= byte_cap {
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -1332,6 +1357,10 @@ mod masking_interface_cache_concurrency_security_tests;
|
||||
#[path = "tests/masking_production_cap_regression_security_tests.rs"]
|
||||
mod masking_production_cap_regression_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_relay_manual_perf_tests.rs"]
|
||||
mod masking_relay_manual_perf_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_extended_attack_surface_security_tests.rs"]
|
||||
mod masking_extended_attack_surface_security_tests;
|
||||
|
||||
@@ -276,20 +276,17 @@ pub(in crate::proxy::middle_relay) fn compute_intermediate_secure_wire_len(
|
||||
let wire_len = data_len
|
||||
.checked_add(padding_len)
|
||||
.ok_or_else(|| ProxyError::Proxy("Frame length overflow".into()))?;
|
||||
if wire_len > 0x7fff_ffffusize {
|
||||
return Err(ProxyError::Proxy(format!(
|
||||
"Intermediate/Secure frame too large: {wire_len}"
|
||||
)));
|
||||
}
|
||||
|
||||
let len_val =
|
||||
crate::protocol::framing::encode_intermediate_header(wire_len, quickack).ok_or_else(
|
||||
|| {
|
||||
ProxyError::Proxy(format!(
|
||||
"Intermediate/Secure frame too large: {wire_len}"
|
||||
))
|
||||
},
|
||||
)?;
|
||||
let total = 4usize
|
||||
.checked_add(wire_len)
|
||||
.ok_or_else(|| ProxyError::Proxy("Frame buffer size overflow".into()))?;
|
||||
let mut len_val = u32::try_from(wire_len)
|
||||
.map_err(|_| ProxyError::Proxy("Frame length conversion overflow".into()))?;
|
||||
if quickack {
|
||||
len_val |= 0x8000_0000;
|
||||
}
|
||||
Ok((len_val, total))
|
||||
}
|
||||
|
||||
|
||||
@@ -236,10 +236,10 @@ where
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
let quickack = (len_buf[3] & 0x80) != 0;
|
||||
let header = crate::protocol::framing::parse_intermediate_header(len_buf);
|
||||
(
|
||||
(u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize,
|
||||
quickack,
|
||||
header.wire_len,
|
||||
header.quickack,
|
||||
Some(len_buf),
|
||||
)
|
||||
}
|
||||
@@ -331,7 +331,8 @@ where
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Secure Intermediate: strip validated trailing padding bytes.
|
||||
// Secure Intermediate strips only non-aligned tail padding; full-word
|
||||
// padding is indistinguishable from payload in VersionD framing.
|
||||
if proto_tag == ProtoTag::Secure {
|
||||
payload.truncate(secure_payload_len);
|
||||
}
|
||||
|
||||
111
src/proxy/tests/masking_relay_manual_perf_tests.rs
Normal file
111
src/proxy/tests/masking_relay_manual_perf_tests.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
use super::*;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
const PERF_TOTAL_BYTES: usize = 64 * 1024 * 1024;
|
||||
|
||||
struct PatternReader {
|
||||
remaining: usize,
|
||||
chunk: usize,
|
||||
read_calls: AtomicUsize,
|
||||
}
|
||||
|
||||
impl PatternReader {
|
||||
fn new(total: usize, chunk: usize) -> Self {
|
||||
Self {
|
||||
remaining: total,
|
||||
chunk,
|
||||
read_calls: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_calls(&self) -> usize {
|
||||
self.read_calls.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for PatternReader {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
self.read_calls.fetch_add(1, Ordering::Relaxed);
|
||||
if self.remaining == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
let take = self.remaining.min(self.chunk).min(buf.remaining());
|
||||
if take == 0 {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
static PATTERN: [u8; MASK_BUFFER_MAX_SIZE] = [0xA5; MASK_BUFFER_MAX_SIZE];
|
||||
buf.put_slice(&PATTERN[..take]);
|
||||
self.remaining -= take;
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct CountingWriter {
|
||||
written: usize,
|
||||
}
|
||||
|
||||
impl AsyncWrite for CountingWriter {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.written = self.written.saturating_add(buf.len());
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "manual benchmark: throughput-sensitive and host-dependent"]
|
||||
async fn masking_copy_with_idle_timeout_manual_throughput() {
|
||||
let mut reader = PatternReader::new(PERF_TOTAL_BYTES, MASK_BUFFER_MAX_SIZE);
|
||||
let mut writer = CountingWriter::default();
|
||||
let started = Instant::now();
|
||||
|
||||
let outcome = copy_with_idle_timeout(
|
||||
&mut reader,
|
||||
&mut writer,
|
||||
PERF_TOTAL_BYTES,
|
||||
true,
|
||||
Duration::from_secs(30),
|
||||
)
|
||||
.await;
|
||||
|
||||
let elapsed = started.elapsed();
|
||||
let mb = PERF_TOTAL_BYTES as f64 / (1024.0 * 1024.0);
|
||||
let mbps = mb / elapsed.as_secs_f64();
|
||||
|
||||
assert_eq!(outcome.total, PERF_TOTAL_BYTES);
|
||||
assert_eq!(writer.written, PERF_TOTAL_BYTES);
|
||||
assert!(
|
||||
!outcome.ended_by_eof,
|
||||
"manual throughput run should terminate at byte cap"
|
||||
);
|
||||
|
||||
eprintln!(
|
||||
"masking manual throughput: bytes={} elapsed_ms={} mib_per_sec={:.2} read_calls={}",
|
||||
PERF_TOTAL_BYTES,
|
||||
elapsed.as_millis(),
|
||||
mbps,
|
||||
reader.read_calls()
|
||||
);
|
||||
}
|
||||
@@ -15,6 +15,7 @@ use crate::crypto::SecureRandom;
|
||||
use crate::protocol::constants::{
|
||||
ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len,
|
||||
};
|
||||
use crate::protocol::framing::{encode_intermediate_header, parse_intermediate_header};
|
||||
|
||||
// ============= Unified Codec =============
|
||||
|
||||
@@ -197,13 +198,9 @@ fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result<Option
|
||||
}
|
||||
|
||||
let mut meta = FrameMeta::new();
|
||||
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
|
||||
|
||||
// Check QuickACK flag
|
||||
if len >= 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
let header = parse_intermediate_header([src[0], src[1], src[2], src[3]]);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
|
||||
// Validate size
|
||||
if len > max_size {
|
||||
@@ -239,10 +236,12 @@ fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
|
||||
dst.reserve(4 + data.len());
|
||||
|
||||
let mut len = data.len() as u32;
|
||||
if frame.meta.quickack {
|
||||
len |= 0x80000000;
|
||||
}
|
||||
let len = encode_intermediate_header(data.len(), frame.meta.quickack).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("frame too large: {} bytes", data.len()),
|
||||
)
|
||||
})?;
|
||||
|
||||
dst.extend_from_slice(&len.to_le_bytes());
|
||||
dst.extend_from_slice(data);
|
||||
@@ -258,13 +257,9 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
|
||||
}
|
||||
|
||||
let mut meta = FrameMeta::new();
|
||||
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
|
||||
|
||||
// Check QuickACK flag
|
||||
if len >= 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
let header = parse_intermediate_header([src[0], src[1], src[2], src[3]]);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
|
||||
// Validate size
|
||||
if len > max_size {
|
||||
@@ -317,16 +312,18 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R
|
||||
));
|
||||
}
|
||||
|
||||
// Generate padding that keeps total length non-divisible by 4.
|
||||
// Telegram Desktop VersionD uses a 4-bit random padding length.
|
||||
let padding_len = secure_padding_len(data.len(), rng);
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
dst.reserve(4 + total_len);
|
||||
|
||||
let mut len = total_len as u32;
|
||||
if frame.meta.quickack {
|
||||
len |= 0x80000000;
|
||||
}
|
||||
let len = encode_intermediate_header(total_len, frame.meta.quickack).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("frame too large: {} bytes", total_len),
|
||||
)
|
||||
})?;
|
||||
|
||||
dst.extend_from_slice(&len.to_le_bytes());
|
||||
dst.extend_from_slice(data);
|
||||
@@ -523,6 +520,16 @@ mod tests {
|
||||
use tokio::io::duplex;
|
||||
use tokio_util::codec::{FramedRead, FramedWrite};
|
||||
|
||||
fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) {
|
||||
assert!(decoded.starts_with(original));
|
||||
assert!(
|
||||
(original.len()..=original.len() + 12).contains(&decoded.len()),
|
||||
"Secure decoded payload may retain up to 12 bytes of full-word padding, got {}",
|
||||
decoded.len()
|
||||
);
|
||||
assert_eq!(decoded.len() % 4, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_framed_abridged() {
|
||||
let (client, server) = duplex(4096);
|
||||
@@ -565,7 +572,7 @@ mod tests {
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(&received.data[..], &original[..]);
|
||||
assert_secure_decoded_payload(&received.data, &original);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -588,7 +595,11 @@ mod tests {
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(received.data.len(), 8);
|
||||
if proto_tag == ProtoTag::Secure {
|
||||
assert_secure_decoded_payload(&received.data, &original);
|
||||
} else {
|
||||
assert_eq!(received.data.len(), original.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -642,7 +653,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secure_codec_always_adds_padding_and_jitters_wire_length() {
|
||||
fn secure_codec_uses_tdesktop_padding_range_and_jitters_wire_length() {
|
||||
let codec = SecureCodec::new(Arc::new(SecureRandom::new()));
|
||||
let payload = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
let mut wire_lens = HashSet::new();
|
||||
@@ -652,13 +663,12 @@ mod tests {
|
||||
let mut out = BytesMut::new();
|
||||
codec.encode(&frame, &mut out).unwrap();
|
||||
|
||||
assert!(out.len() > 4 + payload.len());
|
||||
let wire_len = u32::from_le_bytes([out[0], out[1], out[2], out[3]]) as usize;
|
||||
assert_eq!(out.len(), 4 + wire_len);
|
||||
assert!(
|
||||
(payload.len() + 1..=payload.len() + 3).contains(&wire_len),
|
||||
"Secure wire length must be payload+1..3, got {wire_len}"
|
||||
(payload.len()..=payload.len() + 15).contains(&wire_len),
|
||||
"Secure wire length must be payload+0..15, got {wire_len}"
|
||||
);
|
||||
assert_ne!(wire_len % 4, 0, "Secure wire length must be non-4-aligned");
|
||||
wire_lens.insert(wire_len);
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
use super::traits::{FrameMeta, LayeredStream};
|
||||
use crate::crypto::{SecureRandom, crc32};
|
||||
use crate::protocol::constants::*;
|
||||
use crate::protocol::framing::{encode_intermediate_header, parse_intermediate_header};
|
||||
use bytes::Bytes;
|
||||
use std::io::{Error, ErrorKind, Result};
|
||||
use std::sync::Arc;
|
||||
@@ -105,10 +106,17 @@ impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
|
||||
|
||||
if len_div_4 < 0x7f {
|
||||
// Short length (1 byte)
|
||||
self.upstream.write_all(&[len_div_4 as u8]).await?;
|
||||
let mut first = len_div_4 as u8;
|
||||
if meta.quickack {
|
||||
first |= 0x80;
|
||||
}
|
||||
self.upstream.write_all(&[first]).await?;
|
||||
} else if len_div_4 < (1 << 24) {
|
||||
// Long length (4 bytes: 0x7f + 3 bytes)
|
||||
let mut header = [0x7f, 0, 0, 0];
|
||||
if meta.quickack {
|
||||
header[0] |= 0x80;
|
||||
}
|
||||
header[1..4].copy_from_slice(&(len_div_4 as u32).to_le_bytes()[..3]);
|
||||
self.upstream.write_all(&header).await?;
|
||||
} else {
|
||||
@@ -160,13 +168,9 @@ impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
|
||||
let mut len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
// Check QuickACK flag (high bit)
|
||||
if len > 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
let header = parse_intermediate_header(len_bytes);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; len];
|
||||
@@ -204,7 +208,13 @@ impl<W: AsyncWrite + Unpin> IntermediateFrameWriter<W> {
|
||||
if meta.simple_ack {
|
||||
self.upstream.write_all(data).await?;
|
||||
} else {
|
||||
let len_bytes = (data.len() as u32).to_le_bytes();
|
||||
let len = encode_intermediate_header(data.len(), meta.quickack).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("Frame too large: {} bytes", data.len()),
|
||||
)
|
||||
})?;
|
||||
let len_bytes = len.to_le_bytes();
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(data).await?;
|
||||
}
|
||||
@@ -249,13 +259,9 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
|
||||
let mut len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
// Check QuickACK flag
|
||||
if len > 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
let header = parse_intermediate_header(len_bytes);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
|
||||
// Read data (including padding)
|
||||
let mut data = vec![0u8; len];
|
||||
@@ -311,12 +317,18 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
||||
));
|
||||
}
|
||||
|
||||
// Add padding so total length is never divisible by 4 (MTProto Secure)
|
||||
// Telegram Desktop VersionD uses a 4-bit random padding length.
|
||||
let padding_len = secure_padding_len(data.len(), &self.rng);
|
||||
let padding = self.rng.bytes(padding_len);
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
let len_bytes = (total_len as u32).to_le_bytes();
|
||||
let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("Frame too large: {total_len} bytes"),
|
||||
)
|
||||
})?;
|
||||
let len_bytes = len.to_le_bytes();
|
||||
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(data).await?;
|
||||
@@ -559,6 +571,16 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
use tokio::io::duplex;
|
||||
|
||||
fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) {
|
||||
assert!(decoded.starts_with(original));
|
||||
assert!(
|
||||
(original.len()..=original.len() + 12).contains(&decoded.len()),
|
||||
"Secure decoded payload may retain up to 12 bytes of full-word padding, got {}",
|
||||
decoded.len()
|
||||
);
|
||||
assert_eq!(decoded.len() % 4, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_abridged_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
@@ -613,6 +635,43 @@ mod tests {
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_intermediate_quickack_zero_length_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = IntermediateFrameWriter::new(client);
|
||||
let mut reader = IntermediateFrameReader::new(server);
|
||||
|
||||
writer
|
||||
.write_frame(&[], &FrameMeta::new().with_quickack())
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, meta) = reader.read_frame().await.unwrap();
|
||||
assert!(received.is_empty());
|
||||
assert!(meta.quickack);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_abridged_quickack_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = AbridgedFrameWriter::new(client);
|
||||
let mut reader = AbridgedFrameReader::new(server);
|
||||
|
||||
let data = vec![1u8, 2, 3, 4];
|
||||
writer
|
||||
.write_frame(&data, &FrameMeta::new().with_quickack())
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
assert!(meta.quickack);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_secure_intermediate_padding() {
|
||||
let (client, server) = duplex(1024);
|
||||
@@ -625,7 +684,7 @@ mod tests {
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(received.len(), data.len());
|
||||
assert_secure_decoded_payload(&received, &data);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
601
src/synlimit_control.rs
Normal file
601
src/synlimit_control.rs
Normal file
@@ -0,0 +1,601 @@
|
||||
use std::collections::BTreeSet;
|
||||
use std::net::IpAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::watch;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::config::{ProxyConfig, SynLimitMode};
|
||||
|
||||
const IPTABLES_CHAIN: &str = "TELEMT_SYNLIMIT";
|
||||
const IPTABLES_HASHLIMIT_NAME: &str = "TELEMT-BUMPER";
|
||||
const NFT_TABLE: &str = "telemt_synlimit";
|
||||
const NFT_CHAIN: &str = "input";
|
||||
type SynLimitTarget = (Option<IpAddr>, u16, u32, u32, u32);
|
||||
|
||||
#[derive(Default)]
|
||||
struct SynLimitTargets {
|
||||
iptables_v4: Vec<SynLimitTarget>,
|
||||
iptables_v6: Vec<SynLimitTarget>,
|
||||
nft_v4: Vec<SynLimitTarget>,
|
||||
nft_v6: Vec<SynLimitTarget>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct NftTableFamilies {
|
||||
inet: bool,
|
||||
ip: bool,
|
||||
ip6: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum NftFamily {
|
||||
Inet,
|
||||
Ip,
|
||||
Ip6,
|
||||
}
|
||||
|
||||
struct NftApplyPlan<'a> {
|
||||
family: NftFamily,
|
||||
v4_targets: &'a [SynLimitTarget],
|
||||
v6_targets: &'a [SynLimitTarget],
|
||||
}
|
||||
|
||||
impl SynLimitTargets {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.iptables_v4.is_empty()
|
||||
&& self.iptables_v6.is_empty()
|
||||
&& self.nft_v4.is_empty()
|
||||
&& self.nft_v6.is_empty()
|
||||
}
|
||||
|
||||
fn has_iptables_targets(&self) -> bool {
|
||||
!self.iptables_v4.is_empty() || !self.iptables_v6.is_empty()
|
||||
}
|
||||
|
||||
fn has_nft_targets(&self) -> bool {
|
||||
!self.nft_v4.is_empty() || !self.nft_v6.is_empty()
|
||||
}
|
||||
}
|
||||
impl NftFamily {
|
||||
fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Inet => "inet",
|
||||
Self::Ip => "ip",
|
||||
Self::Ip6 => "ip6",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn spawn_synlimit_controller(config_rx: watch::Receiver<Arc<ProxyConfig>>) {
|
||||
if !cfg!(target_os = "linux") {
|
||||
if has_synlimit_config(&config_rx.borrow()) {
|
||||
warn!("SYN limiter is configured but unsupported on this OS; skipping netfilter rules");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
tokio::spawn(async move {
|
||||
wait_for_config_channel_close_and_reconcile(config_rx).await;
|
||||
if let Err(error) = clear_synlimit_rules_all_backends().await {
|
||||
warn!(error = %error, "Failed to clear SYN limiter rules after config channel close");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn wait_for_config_channel_close_and_reconcile(
|
||||
mut config_rx: watch::Receiver<Arc<ProxyConfig>>,
|
||||
) {
|
||||
while config_rx.changed().await.is_ok() {
|
||||
let cfg = config_rx.borrow_and_update().clone();
|
||||
reconcile_synlimit_rules(&cfg).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn reconcile_synlimit_rules(cfg: &ProxyConfig) {
|
||||
if let Err(error) = clear_synlimit_rules_all_backends().await {
|
||||
warn!(error = %error, "Failed to clear existing SYN limiter rules before reconcile");
|
||||
}
|
||||
|
||||
let targets = synlimit_targets(cfg);
|
||||
if targets.is_empty() {
|
||||
return;
|
||||
}
|
||||
if !has_cap_net_admin() {
|
||||
warn!(
|
||||
"SYN limiter configured but CAP_NET_ADMIN is not available; netfilter rules not applied"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if targets.has_iptables_targets()
|
||||
&& let Err(error) = apply_iptables_synlimit_rules(&targets).await
|
||||
{
|
||||
warn!(error = %error, "Failed to apply iptables SYN limiter rules");
|
||||
}
|
||||
if targets.has_nft_targets()
|
||||
&& let Err(error) = apply_nft_synlimit_rules(&targets).await
|
||||
{
|
||||
warn!(error = %error, "Failed to apply nftables SYN limiter rules");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn clear_synlimit_rules_all_backends() -> Result<(), String> {
|
||||
let mut errors = Vec::new();
|
||||
if let Err(error) = clear_nft_synlimit_rules_all_families().await {
|
||||
errors.push(error);
|
||||
}
|
||||
if let Err(error) = clear_iptables_synlimit_rules_for_binary("iptables").await {
|
||||
errors.push(error);
|
||||
}
|
||||
if let Err(error) = clear_iptables_synlimit_rules_for_binary("ip6tables").await {
|
||||
errors.push(error);
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(errors.join("; "))
|
||||
}
|
||||
}
|
||||
|
||||
fn has_synlimit_config(cfg: &ProxyConfig) -> bool {
|
||||
cfg.server
|
||||
.listeners
|
||||
.iter()
|
||||
.any(|listener| !matches!(listener.synlimit, SynLimitMode::Off))
|
||||
}
|
||||
|
||||
fn synlimit_targets(cfg: &ProxyConfig) -> SynLimitTargets {
|
||||
let mut iptables_v4 = BTreeSet::new();
|
||||
let mut iptables_v6 = BTreeSet::new();
|
||||
let mut nft_v4 = BTreeSet::new();
|
||||
let mut nft_v6 = BTreeSet::new();
|
||||
|
||||
for listener in &cfg.server.listeners {
|
||||
let backend = listener.synlimit;
|
||||
if matches!(backend, SynLimitMode::Off) {
|
||||
continue;
|
||||
}
|
||||
let port = listener.port.unwrap_or(cfg.server.port);
|
||||
let ip = (!listener.ip.is_unspecified()).then_some(listener.ip);
|
||||
let seconds = listener.synlimit_seconds;
|
||||
let hitcount = listener.synlimit_hitcount;
|
||||
let burst = listener.synlimit_burst;
|
||||
|
||||
match (backend, listener.ip.is_ipv4()) {
|
||||
(SynLimitMode::Iptables, true) => {
|
||||
iptables_v4.insert((ip, port, seconds, hitcount, burst));
|
||||
}
|
||||
(SynLimitMode::Iptables, false) => {
|
||||
iptables_v6.insert((ip, port, seconds, hitcount, burst));
|
||||
}
|
||||
(SynLimitMode::Nftables, true) => {
|
||||
nft_v4.insert((ip, port, seconds, hitcount, burst));
|
||||
}
|
||||
(SynLimitMode::Nftables, false) => {
|
||||
nft_v6.insert((ip, port, seconds, hitcount, burst));
|
||||
}
|
||||
(SynLimitMode::Off, _) => {}
|
||||
}
|
||||
}
|
||||
|
||||
SynLimitTargets {
|
||||
iptables_v4: iptables_v4.into_iter().collect(),
|
||||
iptables_v6: iptables_v6.into_iter().collect(),
|
||||
nft_v4: nft_v4.into_iter().collect(),
|
||||
nft_v6: nft_v6.into_iter().collect(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn apply_iptables_synlimit_rules(targets: &SynLimitTargets) -> Result<(), String> {
|
||||
apply_iptables_synlimit_rules_for_binary("iptables", &targets.iptables_v4).await?;
|
||||
apply_iptables_synlimit_rules_for_binary("ip6tables", &targets.iptables_v6).await
|
||||
}
|
||||
|
||||
async fn apply_iptables_synlimit_rules_for_binary(
|
||||
binary: &str,
|
||||
targets: &[SynLimitTarget],
|
||||
) -> Result<(), String> {
|
||||
if targets.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let _ = run_command(binary, &["-t", "filter", "-N", IPTABLES_CHAIN], None).await;
|
||||
run_command(binary, &["-t", "filter", "-F", IPTABLES_CHAIN], None).await?;
|
||||
if run_command(
|
||||
binary,
|
||||
&["-t", "filter", "-C", "INPUT", "-j", IPTABLES_CHAIN],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_command(
|
||||
binary,
|
||||
&["-t", "filter", "-A", "INPUT", "-j", IPTABLES_CHAIN],
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
for (idx, (ip, port, seconds, hitcount, burst)) in targets.iter().enumerate() {
|
||||
let hashlimit_name = format!("{IPTABLES_HASHLIMIT_NAME}-{idx}");
|
||||
let accept_args = iptables_hashlimit_accept_rule_args(
|
||||
ip,
|
||||
*port,
|
||||
*seconds,
|
||||
*hitcount,
|
||||
*burst,
|
||||
&hashlimit_name,
|
||||
);
|
||||
let drop_args = iptables_synlimit_drop_rule_args(ip, *port);
|
||||
let drop_refs: Vec<&str> = drop_args.iter().map(String::as_str).collect();
|
||||
let accept_refs: Vec<&str> = accept_args.iter().map(String::as_str).collect();
|
||||
run_command(binary, &accept_refs, None).await?;
|
||||
run_command(binary, &drop_refs, None).await?;
|
||||
}
|
||||
run_command(
|
||||
binary,
|
||||
&["-t", "filter", "-A", IPTABLES_CHAIN, "-j", "RETURN"],
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn iptables_hashlimit_accept_rule_args(
|
||||
ip: &Option<IpAddr>,
|
||||
port: u16,
|
||||
seconds: u32,
|
||||
hitcount: u32,
|
||||
burst: u32,
|
||||
hashlimit_name: &str,
|
||||
) -> Vec<String> {
|
||||
let mut args = vec![
|
||||
"-t".to_string(),
|
||||
"filter".to_string(),
|
||||
"-A".to_string(),
|
||||
IPTABLES_CHAIN.to_string(),
|
||||
"-p".to_string(),
|
||||
"tcp".to_string(),
|
||||
"--syn".to_string(),
|
||||
];
|
||||
if let Some(ip) = ip {
|
||||
args.push("-d".to_string());
|
||||
args.push(ip.to_string());
|
||||
}
|
||||
let rate = synlimit_rate_arg(seconds, hitcount);
|
||||
args.extend([
|
||||
"--dport".to_string(),
|
||||
port.to_string(),
|
||||
"-m".to_string(),
|
||||
"hashlimit".to_string(),
|
||||
"--hashlimit-name".to_string(),
|
||||
hashlimit_name.to_string(),
|
||||
"--hashlimit-mode".to_string(),
|
||||
"srcip".to_string(),
|
||||
"--hashlimit-upto".to_string(),
|
||||
rate,
|
||||
"--hashlimit-burst".to_string(),
|
||||
burst.to_string(),
|
||||
"--hashlimit-htable-expire".to_string(),
|
||||
"15000".to_string(),
|
||||
"-j".to_string(),
|
||||
"ACCEPT".to_string(),
|
||||
]);
|
||||
args
|
||||
}
|
||||
|
||||
fn iptables_synlimit_drop_rule_args(ip: &Option<IpAddr>, port: u16) -> Vec<String> {
|
||||
let mut args = vec![
|
||||
"-t".to_string(),
|
||||
"filter".to_string(),
|
||||
"-A".to_string(),
|
||||
IPTABLES_CHAIN.to_string(),
|
||||
"-p".to_string(),
|
||||
"tcp".to_string(),
|
||||
"--syn".to_string(),
|
||||
];
|
||||
if let Some(ip) = ip {
|
||||
args.push("-d".to_string());
|
||||
args.push(ip.to_string());
|
||||
}
|
||||
args.extend([
|
||||
"--dport".to_string(),
|
||||
port.to_string(),
|
||||
"-j".to_string(),
|
||||
"DROP".to_string(),
|
||||
]);
|
||||
args
|
||||
}
|
||||
|
||||
fn synlimit_rate_arg(seconds: u32, hitcount: u32) -> String {
|
||||
let seconds = u64::from(seconds.max(1));
|
||||
let hitcount = u64::from(hitcount.max(1));
|
||||
for (unit_seconds, unit_name) in [
|
||||
(1_u64, "second"),
|
||||
(60_u64, "minute"),
|
||||
(3_600_u64, "hour"),
|
||||
(86_400_u64, "day"),
|
||||
] {
|
||||
let amount = hitcount.saturating_mul(unit_seconds);
|
||||
if amount >= seconds && amount % seconds == 0 {
|
||||
return format!("{}/{}", amount / seconds, unit_name);
|
||||
}
|
||||
}
|
||||
let amount = hitcount.saturating_mul(86_400).saturating_add(seconds - 1) / seconds;
|
||||
format!("{}/day", amount.max(1))
|
||||
}
|
||||
|
||||
async fn clear_iptables_synlimit_rules_for_binary(binary: &str) -> Result<(), String> {
|
||||
let mut errors = Vec::new();
|
||||
for _ in 0..8 {
|
||||
match run_command(
|
||||
binary,
|
||||
&["-t", "filter", "-D", "INPUT", "-j", IPTABLES_CHAIN],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {}
|
||||
Err(error) if is_missing_command_or_iptables_rule(&error) => break,
|
||||
Err(error) => {
|
||||
errors.push(format!("{binary} delete INPUT jump failed: {error}"));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Err(error) = run_command(binary, &["-t", "filter", "-F", IPTABLES_CHAIN], None).await
|
||||
&& !is_missing_command_or_iptables_rule(&error)
|
||||
{
|
||||
errors.push(format!("{binary} flush chain failed: {error}"));
|
||||
}
|
||||
if let Err(error) = run_command(binary, &["-t", "filter", "-X", IPTABLES_CHAIN], None).await
|
||||
&& !is_missing_command_or_iptables_rule(&error)
|
||||
{
|
||||
errors.push(format!("{binary} delete chain failed: {error}"));
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(errors.join(", "))
|
||||
}
|
||||
}
|
||||
|
||||
async fn apply_nft_synlimit_rules(targets: &SynLimitTargets) -> Result<(), String> {
|
||||
let families = detect_nft_table_families().await;
|
||||
for plan in nft_apply_plan(families, &targets.nft_v4, &targets.nft_v6) {
|
||||
let script = nft_synlimit_script(plan);
|
||||
run_command("nft", &["-f", "-"], Some(script)).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn detect_nft_table_families() -> NftTableFamilies {
|
||||
let Ok(output) = run_command_stdout("nft", &["list", "tables"]).await else {
|
||||
return NftTableFamilies {
|
||||
inet: false,
|
||||
ip: false,
|
||||
ip6: false,
|
||||
};
|
||||
};
|
||||
|
||||
let mut families = NftTableFamilies {
|
||||
inet: false,
|
||||
ip: false,
|
||||
ip6: false,
|
||||
};
|
||||
for line in output.lines() {
|
||||
let mut fields = line.split_whitespace();
|
||||
if fields.next() != Some("table") {
|
||||
continue;
|
||||
}
|
||||
match fields.next() {
|
||||
Some("inet") => families.inet = true,
|
||||
Some("ip") => families.ip = true,
|
||||
Some("ip6") => families.ip6 = true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
families
|
||||
}
|
||||
fn nft_apply_plan<'a>(
|
||||
families: NftTableFamilies,
|
||||
v4_targets: &'a [SynLimitTarget],
|
||||
v6_targets: &'a [SynLimitTarget],
|
||||
) -> Vec<NftApplyPlan<'a>> {
|
||||
if !v4_targets.is_empty() && !v6_targets.is_empty() {
|
||||
return vec![NftApplyPlan {
|
||||
family: NftFamily::Inet,
|
||||
v4_targets,
|
||||
v6_targets,
|
||||
}];
|
||||
}
|
||||
if !v4_targets.is_empty() {
|
||||
return vec![NftApplyPlan {
|
||||
family: if families.inet || !families.ip {
|
||||
NftFamily::Inet
|
||||
} else {
|
||||
NftFamily::Ip
|
||||
},
|
||||
v4_targets,
|
||||
v6_targets: &[],
|
||||
}];
|
||||
}
|
||||
if !v6_targets.is_empty() {
|
||||
return vec![NftApplyPlan {
|
||||
family: if families.inet || !families.ip6 {
|
||||
NftFamily::Inet
|
||||
} else {
|
||||
NftFamily::Ip6
|
||||
},
|
||||
v4_targets: &[],
|
||||
v6_targets,
|
||||
}];
|
||||
}
|
||||
Vec::new()
|
||||
}
|
||||
fn nft_synlimit_script(plan: NftApplyPlan<'_>) -> String {
|
||||
let mut script = String::new();
|
||||
script.push_str(&format!("table {} {NFT_TABLE} {{\n", plan.family.as_str()));
|
||||
script.push_str(&format!(" chain {NFT_CHAIN} {{\n"));
|
||||
script.push_str(" type filter hook input priority filter; policy accept;\n");
|
||||
for (idx, (ip, port, seconds, hitcount, burst)) in plan.v4_targets.iter().enumerate() {
|
||||
let daddr = ip
|
||||
.map(|ip| format!(" ip daddr {ip}"))
|
||||
.unwrap_or_else(String::new);
|
||||
let rate = synlimit_rate_arg(*seconds, *hitcount);
|
||||
script.push_str(&format!(
|
||||
" tcp flags & (fin|syn|rst|ack) == syn{daddr} tcp dport {port} meter telemt_synlimit_v4_{idx} {{ ip saddr limit rate over {rate} burst {burst} packets }} drop\n"
|
||||
));
|
||||
script.push_str(&format!(
|
||||
" tcp flags & (fin|syn|rst|ack) == syn{daddr} tcp dport {port} accept\n"
|
||||
));
|
||||
}
|
||||
for (idx, (ip, port, seconds, hitcount, burst)) in plan.v6_targets.iter().enumerate() {
|
||||
let daddr = ip
|
||||
.map(|ip| format!(" ip6 daddr {ip}"))
|
||||
.unwrap_or_else(String::new);
|
||||
let rate = synlimit_rate_arg(*seconds, *hitcount);
|
||||
script.push_str(&format!(
|
||||
" tcp flags & (fin|syn|rst|ack) == syn{daddr} tcp dport {port} meter telemt_synlimit_v6_{idx} {{ ip6 saddr limit rate over {rate} burst {burst} packets }} drop\n"
|
||||
));
|
||||
script.push_str(&format!(
|
||||
" tcp flags & (fin|syn|rst|ack) == syn{daddr} tcp dport {port} accept\n"
|
||||
));
|
||||
}
|
||||
script.push_str(" }\n");
|
||||
script.push_str("}\n");
|
||||
script
|
||||
}
|
||||
|
||||
async fn clear_nft_synlimit_rules_all_families() -> Result<(), String> {
|
||||
let mut errors = Vec::new();
|
||||
for family in [NftFamily::Inet, NftFamily::Ip, NftFamily::Ip6] {
|
||||
if let Err(error) = run_command(
|
||||
"nft",
|
||||
&["delete", "table", family.as_str(), NFT_TABLE],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
&& !is_missing_command_or_nft_table(&error)
|
||||
{
|
||||
errors.push(format!(
|
||||
"nft delete table {} {NFT_TABLE} failed: {error}",
|
||||
family.as_str()
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(errors.join(", "))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_missing_command_or_iptables_rule(error: &str) -> bool {
|
||||
error.contains("is not available")
|
||||
|| error.contains("No chain/target/match by that name")
|
||||
|| error.contains("does not exist")
|
||||
}
|
||||
|
||||
fn is_missing_command_or_nft_table(error: &str) -> bool {
|
||||
error.contains("is not available") || error.contains("No such file or directory")
|
||||
}
|
||||
|
||||
async fn run_command(binary: &str, args: &[&str], stdin: Option<String>) -> Result<(), String> {
|
||||
let Some(command_path) = resolve_command(binary) else {
|
||||
return Err(format!("{binary} is not available"));
|
||||
};
|
||||
let mut command = Command::new(command_path);
|
||||
command.args(args);
|
||||
if stdin.is_some() {
|
||||
command.stdin(std::process::Stdio::piped());
|
||||
}
|
||||
command.stdout(std::process::Stdio::null());
|
||||
command.stderr(std::process::Stdio::piped());
|
||||
let mut child = command
|
||||
.spawn()
|
||||
.map_err(|e| format!("spawn {binary} failed: {e}"))?;
|
||||
if let Some(blob) = stdin
|
||||
&& let Some(mut writer) = child.stdin.take()
|
||||
{
|
||||
writer
|
||||
.write_all(blob.as_bytes())
|
||||
.await
|
||||
.map_err(|e| format!("stdin write {binary} failed: {e}"))?;
|
||||
}
|
||||
let output = child
|
||||
.wait_with_output()
|
||||
.await
|
||||
.map_err(|e| format!("wait {binary} failed: {e}"))?;
|
||||
if output.status.success() {
|
||||
return Ok(());
|
||||
}
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
|
||||
Err(if stderr.is_empty() {
|
||||
format!("{binary} exited with status {}", output.status)
|
||||
} else {
|
||||
stderr
|
||||
})
|
||||
}
|
||||
|
||||
async fn run_command_stdout(binary: &str, args: &[&str]) -> Result<String, String> {
|
||||
let Some(command_path) = resolve_command(binary) else {
|
||||
return Err(format!("{binary} is not available"));
|
||||
};
|
||||
let output = Command::new(command_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| format!("wait {binary} failed: {e}"))?;
|
||||
if output.status.success() {
|
||||
return Ok(String::from_utf8_lossy(&output.stdout).to_string());
|
||||
}
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
|
||||
Err(if stderr.is_empty() {
|
||||
format!("{binary} exited with status {}", output.status)
|
||||
} else {
|
||||
stderr
|
||||
})
|
||||
}
|
||||
|
||||
fn resolve_command(binary: &str) -> Option<PathBuf> {
|
||||
let mut dirs = std::env::var_os("PATH")
|
||||
.map(|path| std::env::split_paths(&path).collect::<Vec<_>>())
|
||||
.unwrap_or_default();
|
||||
dirs.extend(["/usr/sbin", "/sbin", "/usr/bin", "/bin"].map(PathBuf::from));
|
||||
dirs.into_iter()
|
||||
.map(|dir| dir.join(binary))
|
||||
.find(|candidate| candidate.exists() && candidate.is_file())
|
||||
}
|
||||
|
||||
fn has_cap_net_admin() -> bool {
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
let Ok(status) = std::fs::read_to_string("/proc/self/status") else {
|
||||
return false;
|
||||
};
|
||||
for line in status.lines() {
|
||||
if let Some(raw) = line.strip_prefix("CapEff:") {
|
||||
let caps = raw.trim();
|
||||
if let Ok(bits) = u64::from_str_radix(caps, 16) {
|
||||
const CAP_NET_ADMIN_BIT: u64 = 12;
|
||||
return (bits & (1u64 << CAP_NET_ADMIN_BIT)) != 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
{
|
||||
false
|
||||
}
|
||||
}
|
||||
@@ -155,57 +155,35 @@ fn push_fallback_size(sizes: &mut Vec<usize>, size: usize) {
|
||||
}
|
||||
|
||||
fn fallback_family_app_data_sizes(cached: &CachedTlsData) -> Vec<usize> {
|
||||
if matches!(cached.behavior_profile.source, TlsProfileSource::Rustls)
|
||||
&& !cached.app_data_records_sizes.is_empty()
|
||||
{
|
||||
return cached.app_data_records_sizes.clone();
|
||||
}
|
||||
|
||||
let family = fallback_shape_family(cached);
|
||||
let mut remaining = fallback_total_app_data_len(cached);
|
||||
let preferred_chunk = match family {
|
||||
FallbackShapeFamily::NginxLike => 2896,
|
||||
FallbackShapeFamily::BoringSslLike => 1369,
|
||||
FallbackShapeFamily::RustlsLike => 2048,
|
||||
let mut sizes = Vec::with_capacity(1);
|
||||
let size = if matches!(cached.behavior_profile.source, TlsProfileSource::Rustls) {
|
||||
cached
|
||||
.app_data_records_sizes
|
||||
.first()
|
||||
.copied()
|
||||
.unwrap_or_else(|| fallback_total_app_data_len(cached))
|
||||
} else {
|
||||
fallback_total_app_data_len(cached)
|
||||
};
|
||||
let split_threshold = match family {
|
||||
FallbackShapeFamily::NginxLike => 4096,
|
||||
FallbackShapeFamily::BoringSslLike => 1536,
|
||||
FallbackShapeFamily::RustlsLike => 3072,
|
||||
};
|
||||
|
||||
if remaining <= split_threshold {
|
||||
return vec![remaining.clamp(MIN_APP_DATA, MAX_APP_DATA)];
|
||||
}
|
||||
|
||||
let mut sizes: Vec<usize> = Vec::new();
|
||||
while remaining > 0 {
|
||||
let chunk = remaining.min(preferred_chunk).min(MAX_APP_DATA);
|
||||
if chunk < MIN_APP_DATA {
|
||||
if let Some(last) = sizes.last_mut() {
|
||||
*last = (*last).saturating_add(chunk).min(MAX_APP_DATA);
|
||||
} else {
|
||||
push_fallback_size(&mut sizes, chunk);
|
||||
}
|
||||
break;
|
||||
}
|
||||
push_fallback_size(&mut sizes, chunk);
|
||||
remaining = remaining.saturating_sub(chunk);
|
||||
}
|
||||
|
||||
push_fallback_size(&mut sizes, size);
|
||||
sizes
|
||||
}
|
||||
|
||||
fn emulated_app_data_sizes(cached: &CachedTlsData) -> Vec<usize> {
|
||||
match cached.behavior_profile.source {
|
||||
TlsProfileSource::Raw | TlsProfileSource::Merged => {
|
||||
if !cached.behavior_profile.app_data_record_sizes.is_empty() {
|
||||
return cached.behavior_profile.app_data_record_sizes.clone();
|
||||
if let Some(size) = cached.behavior_profile.app_data_record_sizes.first() {
|
||||
return vec![(*size).clamp(MIN_APP_DATA, MAX_APP_DATA)];
|
||||
}
|
||||
if !cached.app_data_records_sizes.is_empty() {
|
||||
return cached.app_data_records_sizes.clone();
|
||||
if let Some(size) = cached.app_data_records_sizes.first() {
|
||||
return vec![(*size).clamp(MIN_APP_DATA, MAX_APP_DATA)];
|
||||
}
|
||||
return vec![cached.total_app_data_len.max(1024)];
|
||||
return vec![
|
||||
cached
|
||||
.total_app_data_len
|
||||
.max(1024)
|
||||
.clamp(MIN_APP_DATA, MAX_APP_DATA),
|
||||
];
|
||||
}
|
||||
TlsProfileSource::Default | TlsProfileSource::Rustls => {
|
||||
return fallback_family_app_data_sizes(cached);
|
||||
@@ -417,7 +395,7 @@ pub fn build_emulated_server_hello(
|
||||
alpn: Option<Vec<u8>>,
|
||||
new_session_tickets: u8,
|
||||
) -> Vec<u8> {
|
||||
// --- ServerHello ---
|
||||
// ServerHello carries the authenticated digest bytes that the client verifies.
|
||||
let extensions = build_profiled_server_hello_extensions(cached, server_key_share);
|
||||
let extensions_len = extensions.len() as u16;
|
||||
|
||||
@@ -449,7 +427,7 @@ pub fn build_emulated_server_hello(
|
||||
server_hello.extend_from_slice(&(message.len() as u16).to_be_bytes());
|
||||
server_hello.extend_from_slice(&message);
|
||||
|
||||
// --- ChangeCipherSpec ---
|
||||
// ChangeCipherSpec is part of the client-visible TLS shim prefix.
|
||||
let change_cipher_spec_count = emulated_change_cipher_spec_count(cached);
|
||||
let mut change_cipher_spec = Vec::with_capacity(change_cipher_spec_count * 6);
|
||||
for _ in 0..change_cipher_spec_count {
|
||||
@@ -463,7 +441,8 @@ pub fn build_emulated_server_hello(
|
||||
]);
|
||||
}
|
||||
|
||||
// --- ApplicationData (fake encrypted records) ---
|
||||
// Telegram clients authenticate the hello prefix and then expose any later
|
||||
// ApplicationData bytes to the MTProto packet parser.
|
||||
let mut sizes = {
|
||||
let base_sizes = emulated_app_data_sizes(cached);
|
||||
match cached.behavior_profile.source {
|
||||
@@ -550,8 +529,7 @@ pub fn build_emulated_server_hello(
|
||||
app_data.extend_from_slice(&rec);
|
||||
}
|
||||
|
||||
// --- Combine ---
|
||||
// Optional NewSessionTicket mimic records (opaque ApplicationData for fingerprint).
|
||||
// Optional NewSessionTicket mimic records are an explicit fingerprint opt-in.
|
||||
let mut tickets = Vec::new();
|
||||
for ticket_len in emulated_ticket_record_sizes(cached, new_session_tickets, rng) {
|
||||
let mut rec = Vec::with_capacity(5 + ticket_len);
|
||||
@@ -570,7 +548,7 @@ pub fn build_emulated_server_hello(
|
||||
response.extend_from_slice(&app_data);
|
||||
response.extend_from_slice(&tickets);
|
||||
|
||||
// --- HMAC ---
|
||||
// The digest authenticates the server response bytes emitted by this builder.
|
||||
let mut hmac_input = Vec::with_capacity(TLS_DIGEST_LEN + response.len());
|
||||
hmac_input.extend_from_slice(client_digest);
|
||||
hmac_input.extend_from_slice(&response);
|
||||
@@ -1062,7 +1040,7 @@ mod tests {
|
||||
app_lens.push(record_len);
|
||||
pos += 5 + record_len;
|
||||
}
|
||||
assert_eq!(app_lens, vec![64, 3905, 537]);
|
||||
assert_eq!(app_lens, vec![64]);
|
||||
assert_eq!(pos, response.len());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,7 +106,37 @@ fn emulated_server_hello_does_not_emit_profile_ticket_tail_when_disabled() {
|
||||
);
|
||||
|
||||
let app_records = record_lengths_by_type(&response, TLS_RECORD_APPLICATION);
|
||||
assert_eq!(app_records, vec![1200, 900]);
|
||||
assert_eq!(app_records, vec![1200]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn emulated_server_hello_keeps_default_profile_primary_app_data_single() {
|
||||
let mut cached = make_cached();
|
||||
cached.behavior_profile.source = TlsProfileSource::Default;
|
||||
cached.behavior_profile.app_data_record_sizes.clear();
|
||||
cached.behavior_profile.ticket_record_sizes.clear();
|
||||
cached.app_data_records_sizes = vec![2048, 1024];
|
||||
cached.total_app_data_len = 5000;
|
||||
let rng = SecureRandom::new();
|
||||
|
||||
let response = build_emulated_server_hello(
|
||||
b"secret",
|
||||
&[0x85; 32],
|
||||
&[0x86; 16],
|
||||
&cached,
|
||||
false,
|
||||
true,
|
||||
ClientHelloTlsVersion::Tls13,
|
||||
[0x13, 0x01],
|
||||
&test_server_key_share(),
|
||||
&rng,
|
||||
None,
|
||||
0,
|
||||
);
|
||||
|
||||
let app_records = record_lengths_by_type(&response, TLS_RECORD_APPLICATION);
|
||||
assert_eq!(app_records.len(), 1);
|
||||
assert!(app_records[0] >= 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -130,5 +160,5 @@ fn emulated_server_hello_uses_profile_ticket_lengths_when_enabled() {
|
||||
);
|
||||
|
||||
let app_records = record_lengths_by_type(&response, TLS_RECORD_APPLICATION);
|
||||
assert_eq!(app_records, vec![1200, 900, 220, 180]);
|
||||
assert_eq!(app_records, vec![1200, 220, 180]);
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use tokio::time::timeout;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::config::MeSocksKdfPolicy;
|
||||
use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256};
|
||||
use crate::crypto::{SecureRandom, derive_middleproxy_keys};
|
||||
use crate::error::{ProxyError, Result};
|
||||
use crate::network::IpFamily;
|
||||
use crate::network::probe::is_bogon;
|
||||
@@ -292,14 +292,17 @@ impl MePool {
|
||||
BndPortStatus::Error
|
||||
};
|
||||
record_bnd_status(bnd_addr_status, bnd_port_status, raw_socks_bound_addr);
|
||||
let reflected = if let Some(bound) = socks_bound_addr {
|
||||
let socks_bound_kdf_addr = socks_bound_addr.filter(|bound| bound.port() != 0);
|
||||
// SOCKS BND is the only reflected source that can supply both KDF IP and
|
||||
// port. Direct STUN reflection is IP-only and keeps the TCP local port.
|
||||
let reflected = if let Some(bound) = socks_bound_kdf_addr {
|
||||
Some(bound)
|
||||
} else if is_socks_route {
|
||||
match self.socks_kdf_policy() {
|
||||
MeSocksKdfPolicy::Strict => {
|
||||
self.stats.increment_me_socks_kdf_strict_reject();
|
||||
return Err(ProxyError::InvalidHandshake(
|
||||
"SOCKS route returned no valid BND.ADDR for ME KDF (strict policy)"
|
||||
"SOCKS route returned no valid BND tuple for ME KDF (strict policy)"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
@@ -323,16 +326,14 @@ impl MePool {
|
||||
let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected);
|
||||
let peer_addr_nat =
|
||||
SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port());
|
||||
let client_addr_for_kdf = socks_bound_kdf_addr.unwrap_or(local_addr_nat);
|
||||
if let Some(upstream_info) = upstream_egress {
|
||||
let client_ip_for_kdf = socks_bound_addr
|
||||
.map(|value| value.ip())
|
||||
.unwrap_or(local_addr_nat.ip());
|
||||
record_upstream_bnd_status(
|
||||
upstream_info.upstream_id,
|
||||
bnd_addr_status,
|
||||
bnd_port_status,
|
||||
raw_socks_bound_addr,
|
||||
Some(client_ip_for_kdf),
|
||||
Some(client_addr_for_kdf.ip()),
|
||||
);
|
||||
}
|
||||
let (mut rd, mut wr) = tokio::io::split(stream);
|
||||
@@ -409,6 +410,7 @@ impl MePool {
|
||||
info!(
|
||||
%local_addr,
|
||||
%local_addr_nat,
|
||||
%client_addr_for_kdf,
|
||||
reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string),
|
||||
%peer_addr,
|
||||
%transport_peer_addr,
|
||||
@@ -417,21 +419,20 @@ impl MePool {
|
||||
key_selector = format_args!("0x{ks:08x}"),
|
||||
crypto_schema = format_args!("0x{schema:08x}"),
|
||||
skew_secs = skew,
|
||||
socks_kdf_policy = ?self.socks_kdf_policy(),
|
||||
"ME key derivation parameters"
|
||||
);
|
||||
|
||||
let ts_bytes = crypto_ts.to_le_bytes();
|
||||
let server_port_bytes = peer_addr_nat.port().to_le_bytes();
|
||||
let socks_bound_port = socks_bound_addr
|
||||
.map(|bound| bound.port())
|
||||
.filter(|port| *port != 0);
|
||||
let client_port_for_kdf = socks_bound_port.unwrap_or(local_addr_nat.port());
|
||||
let socks_bound_port = socks_bound_kdf_addr.map(|bound| bound.port());
|
||||
let client_port_for_kdf = client_addr_for_kdf.port();
|
||||
let client_port_source = KdfClientPortSource::from_socks_bound_port(socks_bound_port);
|
||||
let kdf_fingerprint = Self::kdf_material_fingerprint(
|
||||
local_addr_nat.ip(),
|
||||
client_addr_for_kdf.ip(),
|
||||
peer_addr_nat,
|
||||
reflected.map(|value| value.ip()),
|
||||
socks_bound_addr.map(|value| value.ip()),
|
||||
socks_bound_kdf_addr.map(|value| value.ip()),
|
||||
client_port_source,
|
||||
);
|
||||
let previous_kdf_fingerprint = {
|
||||
@@ -473,7 +474,7 @@ impl MePool {
|
||||
let client_port_bytes = client_port_for_kdf.to_le_bytes();
|
||||
|
||||
let server_ip = extract_ip_material(peer_addr_nat);
|
||||
let client_ip = extract_ip_material(local_addr_nat);
|
||||
let client_ip = extract_ip_material(client_addr_for_kdf);
|
||||
|
||||
let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) =
|
||||
match (server_ip, client_ip) {
|
||||
@@ -494,38 +495,6 @@ impl MePool {
|
||||
}
|
||||
};
|
||||
|
||||
let diag_level: u8 = std::env::var("ME_DIAG")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
let prekey_client = build_middleproxy_prekey(
|
||||
&srv_nonce,
|
||||
&my_nonce,
|
||||
&ts_bytes,
|
||||
srv_ip_opt.as_ref().map(|x| &x[..]),
|
||||
&client_port_bytes,
|
||||
b"CLIENT",
|
||||
clt_ip_opt.as_ref().map(|x| &x[..]),
|
||||
&server_port_bytes,
|
||||
&secret,
|
||||
clt_v6_opt.as_ref(),
|
||||
srv_v6_opt.as_ref(),
|
||||
);
|
||||
let prekey_server = build_middleproxy_prekey(
|
||||
&srv_nonce,
|
||||
&my_nonce,
|
||||
&ts_bytes,
|
||||
srv_ip_opt.as_ref().map(|x| &x[..]),
|
||||
&client_port_bytes,
|
||||
b"SERVER",
|
||||
clt_ip_opt.as_ref().map(|x| &x[..]),
|
||||
&server_port_bytes,
|
||||
&secret,
|
||||
clt_v6_opt.as_ref(),
|
||||
srv_v6_opt.as_ref(),
|
||||
);
|
||||
|
||||
let (wk, wi) = derive_middleproxy_keys(
|
||||
&srv_nonce,
|
||||
&my_nonce,
|
||||
@@ -556,47 +525,14 @@ impl MePool {
|
||||
let requested_crc_mode = RpcChecksumMode::Crc32c;
|
||||
let hs_payload = build_handshake_payload(
|
||||
hs_our_ip,
|
||||
local_addr.port(),
|
||||
client_port_for_kdf,
|
||||
hs_peer_ip,
|
||||
peer_addr.port(),
|
||||
peer_addr_nat.port(),
|
||||
requested_crc_mode.advertised_flags(),
|
||||
);
|
||||
let hs_frame = build_rpc_frame(-1, &hs_payload, RpcChecksumMode::Crc32);
|
||||
if diag_level >= 1 {
|
||||
info!(
|
||||
write_key = %hex_dump(&wk),
|
||||
write_iv = %hex_dump(&wi),
|
||||
read_key = %hex_dump(&rk),
|
||||
read_iv = %hex_dump(&ri),
|
||||
srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
|
||||
clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
|
||||
srv_port = %hex_dump(&server_port_bytes),
|
||||
clt_port = %hex_dump(&client_port_bytes),
|
||||
crypto_ts = %hex_dump(&ts_bytes),
|
||||
nonce_srv = %hex_dump(&srv_nonce),
|
||||
nonce_clt = %hex_dump(&my_nonce),
|
||||
prekey_sha256_client = %hex_dump(&sha256(&prekey_client)),
|
||||
prekey_sha256_server = %hex_dump(&sha256(&prekey_server)),
|
||||
hs_plain = %hex_dump(&hs_frame),
|
||||
proxy_secret_sha256 = %hex_dump(&sha256(&secret)),
|
||||
"ME diag: derived keys and handshake plaintext"
|
||||
);
|
||||
}
|
||||
if diag_level >= 2 {
|
||||
info!(
|
||||
prekey_client = %hex_dump(&prekey_client),
|
||||
prekey_server = %hex_dump(&prekey_server),
|
||||
"ME diag: full prekey buffers"
|
||||
);
|
||||
}
|
||||
|
||||
let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
|
||||
if diag_level >= 1 {
|
||||
info!(
|
||||
hs_cipher = %hex_dump(&encrypted_hs),
|
||||
"ME diag: handshake ciphertext"
|
||||
);
|
||||
}
|
||||
wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?;
|
||||
wr.flush().await.map_err(ProxyError::Io)?;
|
||||
|
||||
|
||||
@@ -1728,6 +1728,8 @@ mod tests {
|
||||
false,
|
||||
None,
|
||||
Vec::new(),
|
||||
false,
|
||||
Vec::new(),
|
||||
1,
|
||||
None,
|
||||
12,
|
||||
|
||||
@@ -336,6 +336,8 @@ pub(super) struct NatRuntimeCore {
|
||||
pub(super) nat_probe: bool,
|
||||
pub(super) nat_stun: Option<String>,
|
||||
pub(super) nat_stun_servers: Vec<String>,
|
||||
pub(super) stun_tcp_fallback: bool,
|
||||
pub(super) http_ip_detect_urls: Vec<String>,
|
||||
pub(super) nat_stun_live_servers: Arc<RwLock<Vec<String>>>,
|
||||
pub(super) nat_probe_concurrency: usize,
|
||||
pub(super) detected_ipv6: Option<Ipv6Addr>,
|
||||
@@ -484,6 +486,8 @@ impl MePool {
|
||||
nat_probe: bool,
|
||||
nat_stun: Option<String>,
|
||||
nat_stun_servers: Vec<String>,
|
||||
stun_tcp_fallback: bool,
|
||||
http_ip_detect_urls: Vec<String>,
|
||||
nat_probe_concurrency: usize,
|
||||
detected_ipv6: Option<Ipv6Addr>,
|
||||
me_one_retry: u8,
|
||||
@@ -706,6 +710,8 @@ impl MePool {
|
||||
nat_probe,
|
||||
nat_stun,
|
||||
nat_stun_servers,
|
||||
stun_tcp_fallback,
|
||||
http_ip_detect_urls,
|
||||
nat_stun_live_servers: Arc::new(RwLock::new(Vec::new())),
|
||||
nat_probe_concurrency: nat_probe_concurrency.max(1),
|
||||
detected_ipv6,
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::net::IpAddr;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, info, warn};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::error::{ProxyError, Result};
|
||||
use crate::network::probe::is_bogon;
|
||||
use crate::network::stun::{IpFamily, stun_probe_dual, stun_probe_family_with_bind};
|
||||
use crate::network::probe::{detect_public_ipv4_http, is_bogon};
|
||||
use crate::network::stun::{
|
||||
IpFamily, stun_probe_dual_with_tcp_fallback, stun_probe_family_with_bind_and_tcp_fallback,
|
||||
};
|
||||
|
||||
use super::MePool;
|
||||
use std::time::Instant;
|
||||
|
||||
const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
const STUN_BATCH_TCP_FALLBACK_TIMEOUT: Duration = Duration::from_secs(12);
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn stun_probe(stun_addr: Option<String>) -> Result<crate::network::stun::DualStunResult> {
|
||||
@@ -28,16 +31,13 @@ pub async fn stun_probe(stun_addr: Option<String>) -> Result<crate::network::stu
|
||||
"STUN server is not configured".to_string(),
|
||||
));
|
||||
}
|
||||
stun_probe_dual(&stun_addr).await
|
||||
stun_probe_dual_with_tcp_fallback(&stun_addr, false).await
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn detect_public_ip() -> Option<IpAddr> {
|
||||
fetch_public_ipv4_with_retry()
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.map(IpAddr::V4)
|
||||
let urls = crate::config::defaults::default_http_ip_detect_urls();
|
||||
detect_public_ipv4_http(&urls).await.map(IpAddr::V4)
|
||||
}
|
||||
|
||||
impl MePool {
|
||||
@@ -65,15 +65,26 @@ impl MePool {
|
||||
let mut live_servers = Vec::new();
|
||||
let mut best_by_ip: HashMap<IpAddr, (usize, std::net::SocketAddr)> = HashMap::new();
|
||||
let concurrency = self.nat_runtime.nat_probe_concurrency.max(1);
|
||||
let tcp_fallback = self.nat_runtime.stun_tcp_fallback;
|
||||
|
||||
while next_idx < servers.len() || !join_set.is_empty() {
|
||||
while next_idx < servers.len() && join_set.len() < concurrency {
|
||||
let stun_addr = servers[next_idx].clone();
|
||||
next_idx += 1;
|
||||
join_set.spawn(async move {
|
||||
let batch_timeout = if tcp_fallback {
|
||||
STUN_BATCH_TCP_FALLBACK_TIMEOUT
|
||||
} else {
|
||||
STUN_BATCH_TIMEOUT
|
||||
};
|
||||
let res = timeout(
|
||||
STUN_BATCH_TIMEOUT,
|
||||
stun_probe_family_with_bind(&stun_addr, family, bind_ip),
|
||||
batch_timeout,
|
||||
stun_probe_family_with_bind_and_tcp_fallback(
|
||||
&stun_addr,
|
||||
family,
|
||||
bind_ip,
|
||||
tcp_fallback,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
(stun_addr, res)
|
||||
@@ -193,6 +204,10 @@ impl MePool {
|
||||
return self.nat_runtime.nat_ip_cfg;
|
||||
}
|
||||
|
||||
if !self.nat_runtime.nat_probe {
|
||||
return None;
|
||||
}
|
||||
|
||||
if !(is_bogon(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) {
|
||||
return None;
|
||||
}
|
||||
@@ -201,21 +216,15 @@ impl MePool {
|
||||
return Some(ip);
|
||||
}
|
||||
|
||||
match fetch_public_ipv4_with_retry().await {
|
||||
Ok(Some(ip)) => {
|
||||
{
|
||||
let mut guard = self.nat_runtime.nat_ip_detected.write().await;
|
||||
*guard = Some(IpAddr::V4(ip));
|
||||
}
|
||||
info!(public_ip = %ip, "Auto-detected public IP for NAT translation");
|
||||
Some(IpAddr::V4(ip))
|
||||
}
|
||||
Ok(None) => None,
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Failed to auto-detect public IP");
|
||||
None
|
||||
}
|
||||
let Some(ip) = detect_public_ipv4_http(&self.nat_runtime.http_ip_detect_urls).await else {
|
||||
return None;
|
||||
};
|
||||
{
|
||||
let mut guard = self.nat_runtime.nat_ip_detected.write().await;
|
||||
*guard = Some(IpAddr::V4(ip));
|
||||
}
|
||||
info!(public_ip = %ip, "Auto-detected public IP for NAT translation");
|
||||
Some(IpAddr::V4(ip))
|
||||
}
|
||||
|
||||
pub(super) async fn maybe_reflect_public_addr(
|
||||
@@ -365,31 +374,3 @@ impl MePool {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_public_ipv4_with_retry() -> Result<Option<Ipv4Addr>> {
|
||||
let providers = [
|
||||
"https://checkip.amazonaws.com",
|
||||
"http://v4.ident.me",
|
||||
"http://ipv4.icanhazip.com",
|
||||
];
|
||||
for url in providers {
|
||||
if let Ok(Some(ip)) = fetch_public_ipv4_once(url).await {
|
||||
return Ok(Some(ip));
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn fetch_public_ipv4_once(url: &str) -> Result<Option<Ipv4Addr>> {
|
||||
let res = reqwest::get(url)
|
||||
.await
|
||||
.map_err(|e| ProxyError::Proxy(format!("public IP detection request failed: {e}")))?;
|
||||
|
||||
let text = res
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| ProxyError::Proxy(format!("public IP detection read failed: {e}")))?;
|
||||
|
||||
let ip = text.trim().parse().ok();
|
||||
Ok(ip)
|
||||
}
|
||||
|
||||
@@ -464,8 +464,7 @@ impl MePool {
|
||||
if !self.writer_accepts_new_binding(w) {
|
||||
continue;
|
||||
}
|
||||
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
|
||||
let (payload, meta) = build_routed_payload(effective_our_addr);
|
||||
let (payload, meta) = build_routed_payload(our_addr);
|
||||
match w.tx.clone().try_reserve_owned() {
|
||||
Ok(permit) => {
|
||||
if !self.registry.bind_writer(conn_id, w.id, meta).await {
|
||||
@@ -520,8 +519,7 @@ impl MePool {
|
||||
}
|
||||
self.stats
|
||||
.increment_me_writer_pick_blocking_fallback_total();
|
||||
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
|
||||
let (payload, meta) = build_routed_payload(effective_our_addr);
|
||||
let (payload, meta) = build_routed_payload(our_addr);
|
||||
let reserve_result =
|
||||
if let Some(timeout) = self.route_runtime.me_route_blocking_send_timeout {
|
||||
match tokio::time::timeout(timeout, w.tx.clone().reserve_owned()).await {
|
||||
|
||||
@@ -38,6 +38,8 @@ async fn make_pool(
|
||||
false,
|
||||
None,
|
||||
Vec::new(),
|
||||
false,
|
||||
Vec::new(),
|
||||
1,
|
||||
None,
|
||||
12,
|
||||
|
||||
@@ -36,6 +36,8 @@ async fn make_pool(
|
||||
false,
|
||||
None,
|
||||
Vec::new(),
|
||||
false,
|
||||
Vec::new(),
|
||||
1,
|
||||
None,
|
||||
12,
|
||||
|
||||
@@ -31,6 +31,8 @@ async fn make_pool(me_pool_drain_threshold: u64) -> Arc<MePool> {
|
||||
false,
|
||||
None,
|
||||
Vec::new(),
|
||||
false,
|
||||
Vec::new(),
|
||||
1,
|
||||
None,
|
||||
12,
|
||||
|
||||
@@ -20,6 +20,8 @@ async fn make_pool() -> Arc<MePool> {
|
||||
false,
|
||||
None,
|
||||
Vec::new(),
|
||||
false,
|
||||
Vec::new(),
|
||||
1,
|
||||
None,
|
||||
12,
|
||||
|
||||
@@ -25,6 +25,8 @@ async fn make_pool() -> Arc<MePool> {
|
||||
false,
|
||||
None,
|
||||
Vec::new(),
|
||||
false,
|
||||
Vec::new(),
|
||||
1,
|
||||
None,
|
||||
12,
|
||||
|
||||
@@ -31,6 +31,8 @@ async fn make_pool() -> (Arc<MePool>, Arc<SecureRandom>) {
|
||||
false,
|
||||
None,
|
||||
Vec::new(),
|
||||
false,
|
||||
Vec::new(),
|
||||
1,
|
||||
None,
|
||||
12,
|
||||
@@ -175,6 +177,37 @@ async fn recv_data_count(rx: &mut mpsc::Receiver<WriterCommand>, budget: Duratio
|
||||
data_count
|
||||
}
|
||||
|
||||
async fn recv_first_data_payload(
|
||||
rx: &mut mpsc::Receiver<WriterCommand>,
|
||||
budget: Duration,
|
||||
) -> Option<Vec<u8>> {
|
||||
let start = Instant::now();
|
||||
while Instant::now().duration_since(start) < budget {
|
||||
let remaining = budget.saturating_sub(Instant::now().duration_since(start));
|
||||
match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await {
|
||||
Ok(Some(WriterCommand::Data(payload))) => return Some(payload.to_vec()),
|
||||
Ok(Some(WriterCommand::DataAndFlush(payload))) => return Some(payload.to_vec()),
|
||||
Ok(Some(_)) => {}
|
||||
Ok(None) => break,
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn proxy_req_our_addr_from_payload(payload: &[u8]) -> SocketAddr {
|
||||
const CLIENT_ADDR_WIRE_LEN: usize = 20;
|
||||
const OUR_ADDR_OFFSET: usize = 4 + 4 + 8 + CLIENT_ADDR_WIRE_LEN;
|
||||
|
||||
let our_addr = &payload[OUR_ADDR_OFFSET..OUR_ADDR_OFFSET + CLIENT_ADDR_WIRE_LEN];
|
||||
let ip = Ipv4Addr::new(our_addr[12], our_addr[13], our_addr[14], our_addr[15]);
|
||||
let port = u32::from_le_bytes([our_addr[16], our_addr[17], our_addr[18], our_addr[19]]);
|
||||
SocketAddr::new(
|
||||
IpAddr::V4(ip),
|
||||
u16::try_from(port).expect("test port must fit u16"),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_proxy_req_does_not_replay_when_first_bind_commit_fails() {
|
||||
let (pool, _rng) = make_pool().await;
|
||||
@@ -288,3 +321,47 @@ async fn send_proxy_req_prunes_iterative_stale_bind_failures_without_data_replay
|
||||
drop(writers);
|
||||
assert_eq!(writer_ids, vec![23]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn send_proxy_req_preserves_client_facing_our_addr_when_writer_source_ip_differs() {
|
||||
let (pool, _rng) = make_pool().await;
|
||||
pool.rr.store(0, Ordering::Relaxed);
|
||||
|
||||
let (conn_id, _rx) = pool.registry.register().await;
|
||||
let mut live_rx = insert_writer(
|
||||
&pool,
|
||||
31,
|
||||
2,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 2, 31)), 443),
|
||||
true,
|
||||
)
|
||||
.await;
|
||||
|
||||
{
|
||||
let mut writers = pool.writers.write().await;
|
||||
let writer = writers
|
||||
.iter_mut()
|
||||
.find(|writer| writer.id == 31)
|
||||
.expect("test writer must exist");
|
||||
writer.source_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 31));
|
||||
}
|
||||
|
||||
let our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 7)), 8443);
|
||||
let result = pool
|
||||
.send_proxy_req(
|
||||
conn_id,
|
||||
2,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 7)), 30002),
|
||||
our_addr,
|
||||
b"route",
|
||||
0,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let payload = recv_first_data_payload(&mut live_rx, Duration::from_millis(50))
|
||||
.await
|
||||
.expect("writer must receive routed payload");
|
||||
assert_eq!(proxy_req_our_addr_from_payload(&payload), our_addr);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user