From 613e5d54100bb5517a3b39dbe6a2a2f7e7569620 Mon Sep 17 00:00:00 2001 From: Yaroslav Gurov Date: Tue, 23 Sep 2025 23:57:02 +0200 Subject: [PATCH] feat: implement S3-S4 junked offsets * rework junked headers sending * get rid of redundant allocations --- src/device.c | 49 +++++++++++++++++++++++-------------------------- src/device.h | 3 +-- src/netlink.c | 28 ++++++++++++++++++++-------- src/receive.c | 22 ++++++++++++++-------- src/send.c | 33 ++++++++++----------------------- src/socket.c | 26 ++++++++++---------------- src/socket.h | 6 ++---- 7 files changed, 80 insertions(+), 87 deletions(-) diff --git a/src/device.c b/src/device.c index bf85bf6..3557643 100644 --- a/src/device.c +++ b/src/device.c @@ -540,26 +540,32 @@ int wg_device_handle_post_config(struct net_device *dev, struct amnezia_config * if (asc->junk_packet_max_size != 0) a_sec_on = true; - if (asc->init_packet_junk_size + MESSAGE_INITIATION_SIZE >= MESSAGE_MAX_SIZE) { - net_dbg_ratelimited("%s: init header size (%d) + junkSize (%d) should be smaller than maxSegmentSize: %d\n", - dev->name, MESSAGE_INITIATION_SIZE, - asc->init_packet_junk_size, MESSAGE_MAX_SIZE); - ret = -EINVAL; - } else - wg->advanced_security_config.init_packet_junk_size = asc->init_packet_junk_size; - - if (asc->init_packet_junk_size != 0) + if (wg->junk_size[MSGIDX_HANDSHAKE_INIT] + MESSAGE_INITIATION_SIZE > MESSAGE_MAX_SIZE) { + net_dbg_ratelimited("%s: S1 is too large\n", wg->dev->name); + err = -EINVAL; + } + else a_sec_on = true; - if (asc->response_packet_junk_size + MESSAGE_RESPONSE_SIZE >= MESSAGE_MAX_SIZE) { - net_dbg_ratelimited("%s: response header size (%d) + junkSize (%d) should be smaller than maxSegmentSize: %d\n", - dev->name, MESSAGE_RESPONSE_SIZE, - asc->response_packet_junk_size, MESSAGE_MAX_SIZE); - ret = -EINVAL; - } else - wg->advanced_security_config.response_packet_junk_size = asc->response_packet_junk_size; + if (wg->junk_size[MSGIDX_HANDSHAKE_RESPONSE] + MESSAGE_RESPONSE_SIZE > MESSAGE_MAX_SIZE) { + net_dbg_ratelimited("%s: S2 is too large\n", wg->dev->name); + err = -EINVAL; + } + else + a_sec_on = true; - if (asc->response_packet_junk_size != 0) + if (wg->junk_size[MSGIDX_HANDSHAKE_COOKIE] + MESSAGE_COOKIE_REPLY_SIZE > MESSAGE_MAX_SIZE) { + net_dbg_ratelimited("%s: S3 is too large\n", wg->dev->name); + err = -EINVAL; + } + else + a_sec_on = true; + + if (wg->junk_size[MSGIDX_TRANSPORT] + MESSAGE_TRANSPORT_SIZE > MESSAGE_MAX_SIZE) { + net_dbg_ratelimited("%s: S4 is too large\n", wg->dev->name); + err = -EINVAL; + } + else a_sec_on = true; for (i = 0; i < ARRAY_SIZE(wg->headers); ++i) { @@ -573,15 +579,6 @@ int wg_device_handle_post_config(struct net_device *dev, struct amnezia_config * } } - if (MESSAGE_INITIATION_SIZE + wg->advanced_security_config.init_packet_junk_size == - MESSAGE_RESPONSE_SIZE + wg->advanced_security_config.response_packet_junk_size) { - net_dbg_ratelimited("%s: new init size:%d; and new response size:%d; should differ\n", - dev->name, - MESSAGE_INITIATION_SIZE + asc->init_packet_junk_size, - MESSAGE_RESPONSE_SIZE + asc->response_packet_junk_size); - ret = -EINVAL; - } - wg->advanced_security_config.advanced_security = a_sec_on; out: return ret; diff --git a/src/device.h b/src/device.h index 8397a89..a9c701c 100644 --- a/src/device.h +++ b/src/device.h @@ -43,8 +43,6 @@ struct amnezia_config { u16 junk_packet_count; u16 junk_packet_min_size; u16 junk_packet_max_size; - u16 init_packet_junk_size; - u16 response_packet_junk_size; }; struct wg_device { @@ -67,6 +65,7 @@ struct wg_device { u16 incoming_port; struct magic_header headers[4]; + u16 junk_size[4]; }; int wg_device_init(void); diff --git a/src/netlink.c b/src/netlink.c index 3230103..62fed79 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -51,7 +51,9 @@ static const struct nla_policy device_policy[WGDEVICE_A_MAX + 1] = { [WGDEVICE_A_H2] = { .type = NLA_NUL_STRING }, [WGDEVICE_A_H3] = { .type = NLA_NUL_STRING }, [WGDEVICE_A_H4] = { .type = NLA_NUL_STRING }, - [WGDEVICE_A_PEER] = { .type = NLA_NESTED } + [WGDEVICE_A_PEER] = { .type = NLA_NESTED }, + [WGDEVICE_A_S3] = { .type = NLA_U16 }, + [WGDEVICE_A_S4] = { .type = NLA_U16 } }; static const struct nla_policy peer_policy[WGPEER_A_MAX + 1] = { @@ -434,10 +436,8 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) wg->advanced_security_config.junk_packet_min_size) || nla_put_u16(skb, WGDEVICE_A_JMAX, wg->advanced_security_config.junk_packet_max_size) || - nla_put_u16(skb, WGDEVICE_A_S1, - wg->advanced_security_config.init_packet_junk_size) || - nla_put_u16(skb, WGDEVICE_A_S2, - wg->advanced_security_config.response_packet_junk_size) || + nla_put_u16(skb, WGDEVICE_A_S1, wg->junk_size[MSGIDX_HANDSHAKE_INIT]) || + nla_put_u16(skb, WGDEVICE_A_S2,wg->junk_size[MSGIDX_HANDSHAKE_RESPONSE]) || (mh_genspec(&wg->headers[MSGIDX_HANDSHAKE_INIT], buf, sizeof(buf)) && nla_put_string(skb, WGDEVICE_A_H1, buf)) || (mh_genspec(&wg->headers[MSGIDX_HANDSHAKE_RESPONSE], buf, sizeof(buf)) && @@ -445,7 +445,9 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) (mh_genspec(&wg->headers[MSGIDX_HANDSHAKE_COOKIE], buf, sizeof(buf)) && nla_put_string(skb, WGDEVICE_A_H3, buf)) || (mh_genspec(&wg->headers[MSGIDX_TRANSPORT], buf, sizeof(buf)) && - nla_put_string(skb, WGDEVICE_A_H4, buf)) + nla_put_string(skb, WGDEVICE_A_H4, buf)) || + nla_put_u16(skb, WGDEVICE_A_S3, wg->junk_size[MSGIDX_HANDSHAKE_COOKIE]) || + nla_put_u16(skb, WGDEVICE_A_S4, wg->junk_size[MSGIDX_TRANSPORT]) goto out; down_read(&wg->static_identity.lock); @@ -774,12 +776,12 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info) if (info->attrs[WGDEVICE_A_S1]) { asc->advanced_security = true; - asc->init_packet_junk_size = nla_get_u16(info->attrs[WGDEVICE_A_S1]); + wg->junk_size[MSGIDX_HANDSHAKE_INIT] = nla_get_u16(info->attrs[WGDEVICE_A_S1]); } if (info->attrs[WGDEVICE_A_S2]) { asc->advanced_security = true; - asc->response_packet_junk_size = nla_get_u16(info->attrs[WGDEVICE_A_S2]); + wg->junk_size[MSGIDX_HANDSHAKE_RESPONSE] = nla_get_u16(info->attrs[WGDEVICE_A_S2]); } if (info->attrs[WGDEVICE_A_H1]) { @@ -818,6 +820,16 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info) goto out; } + if (info->attrs[WGDEVICE_A_S3]) { + asc->advanced_security = true; + wg->junk_size[MSGIDX_HANDSHAKE_COOKIE] = nla_get_u16(info->attrs[WGDEVICE_A_S3]); + } + + if (info->attrs[WGDEVICE_A_S4]) { + asc->advanced_security = true; + wg->junk_size[MSGIDX_TRANSPORT] = nla_get_u16(info->attrs[WGDEVICE_A_S4]); + } + if (flags & WGDEVICE_F_REPLACE_PEERS) wg_peer_remove_all(wg); diff --git a/src/receive.c b/src/receive.c index ecfee11..10525f9 100644 --- a/src/receive.c +++ b/src/receive.c @@ -33,30 +33,36 @@ static size_t prepare_awg_message(struct sk_buff *skb, struct wg_device *wg) return 0; } - if (skb->len == wg->advanced_security_config.init_packet_junk_size + MESSAGE_INITIATION_SIZE) { - skb_pull(skb, wg->advanced_security_config.init_packet_junk_size); + if (skb->len == wg->junk_size[MSGIDX_HANDSHAKE_INIT] + MESSAGE_INITIATION_SIZE) { + skb_pull(skb, wg->junk_size[MSGIDX_HANDSHAKE_INIT]); if (mh_validate(SKB_TYPE_LE32(skb), &wg->headers[MSGIDX_HANDSHAKE_INIT])) return MESSAGE_INITIATION_SIZE; else - skb_push(skb, wg->advanced_security_config.init_packet_junk_size); + skb_push(skb, wg->junk_size[MSGIDX_HANDSHAKE_INIT]); } - if (skb->len == wg->advanced_security_config.response_packet_junk_size + MESSAGE_RESPONSE_SIZE) { - skb_pull(skb, wg->advanced_security_config.response_packet_junk_size); + if (skb->len == wg->junk_size[MSGIDX_HANDSHAKE_RESPONSE] + MESSAGE_RESPONSE_SIZE) { + skb_pull(skb, wg->junk_size[MSGIDX_HANDSHAKE_RESPONSE]); if (mh_validate(SKB_TYPE_LE32(skb), &wg->headers[MSGIDX_HANDSHAKE_RESPONSE])) return MESSAGE_RESPONSE_SIZE; else - skb_push(skb, wg->advanced_security_config.response_packet_junk_size); + skb_push(skb, wg->junk_size[MSGIDX_HANDSHAKE_RESPONSE]); } - if (skb->len == MESSAGE_COOKIE_REPLY_SIZE) { + if (skb->len == wg->junk_size[MSGIDX_HANDSHAKE_COOKIE] + MESSAGE_COOKIE_REPLY_SIZE) { + skb_pull(skb, wg->junk_size[MSGIDX_HANDSHAKE_COOKIE]); if (mh_validate(SKB_TYPE_LE32(skb), &wg->headers[MSGIDX_HANDSHAKE_COOKIE])) return MESSAGE_HANDSHAKE_COOKIE; + else + skb_push(skb, wg->junk_size[MSGIDX_HANDSHAKE_COOKIE]); } - if (skb->len >= MESSAGE_TRANSPORT_SIZE) { + if (skb->len >= wg->junk_size[MSGIDX_TRANSPORT] + MESSAGE_TRANSPORT_SIZE) { + skb_pull(skb, wg->junk_size[MSGIDX_TRANSPORT]); if (mh_validate(SKB_TYPE_LE32(skb), &wg->headers[MSGIDX_TRANSPORT])) return MESSAGE_TRANSPORT_SIZE; + else + skb_push(skb, wg->junk_size[MSGIDX_TRANSPORT]); } net_dbg_skb_ratelimited("%s: Unknown message from %pISpfsc encountered, packet dropped\n", diff --git a/src/send.c b/src/send.c index bbe6c9e..3cd4faa 100644 --- a/src/send.c +++ b/src/send.c @@ -56,7 +56,7 @@ static void wg_packet_send_handshake_initiation(struct wg_peer *peer) get_random_bytes(buffer, junk_packet_size); get_random_bytes(&ds, 1); - wg_socket_send_buffer_to_peer(peer, buffer, junk_packet_size, ds); + wg_socket_send_buffer_to_peer(peer, buffer, junk_packet_size, ds, 0); } kfree(buffer); @@ -69,13 +69,7 @@ static void wg_packet_send_handshake_initiation(struct wg_peer *peer) atomic64_set(&peer->last_sent_handshake, ktime_get_coarse_boottime_ns()); - if (wg->advanced_security_config.advanced_security && peer->advanced_security) { - wg_socket_send_junked_buffer_to_peer(peer, &packet, sizeof(packet), - HANDSHAKE_DSCP, wg->advanced_security_config.init_packet_junk_size); - } else { - wg_socket_send_buffer_to_peer(peer, &packet, sizeof(packet), - HANDSHAKE_DSCP); - } + wg_socket_send_buffer_to_peer(peer, &packet, sizeof(packet), HANDSHAKE_DSCP, wg->junk_size[MSGIDX_HANDSHAKE_INIT]); wg_timers_handshake_initiated(peer); } } @@ -138,16 +132,7 @@ void wg_packet_send_handshake_response(struct wg_peer *peer) wg_timers_any_authenticated_packet_sent(peer); atomic64_set(&peer->last_sent_handshake, ktime_get_coarse_boottime_ns()); - if (wg->advanced_security_config.advanced_security && peer->advanced_security) { - wg_socket_send_junked_buffer_to_peer(peer, &packet, - sizeof(packet), - HANDSHAKE_DSCP, - wg->advanced_security_config.response_packet_junk_size); - } else { - wg_socket_send_buffer_to_peer(peer, &packet, - sizeof(packet), - HANDSHAKE_DSCP); - } + wg_socket_send_buffer_to_peer(peer, &packet, sizeof(packet), HANDSHAKE_DSCP, wg->junk_size[MSGIDX_HANDSHAKE_RESPONSE]); } } } @@ -162,7 +147,7 @@ void wg_packet_send_handshake_cookie(struct wg_device *wg, wg->dev->name, initiating_skb); wg_cookie_message_create(&packet, initiating_skb, sender_index, &wg->cookie_checker, mh_genheader(&wg->headers[MSGIDX_HANDSHAKE_COOKIE])); - wg_socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet, sizeof(packet)); + wg_socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet, sizeof(packet), wg->junk_size[MSGIDX_HANDSHAKE_COOKIE]); } static void keep_key_fresh(struct wg_peer *peer) @@ -203,7 +188,7 @@ static unsigned int calculate_skb_padding(struct sk_buff *skb) return padded_size - last_unit; } -static bool encrypt_packet(u32 message_type, struct sk_buff *skb, struct noise_keypair *keypair +static bool encrypt_packet(u32 message_type, size_t junk_size, struct sk_buff *skb, struct noise_keypair *keypair COMPAT_MAYBE_SIMD_CONTEXT(simd_context_t *simd_context)) { unsigned int padding_len, plaintext_len, trailer_len; @@ -235,7 +220,7 @@ static bool encrypt_packet(u32 message_type, struct sk_buff *skb, struct noise_k /* Expand head section to have room for our header and the network * stack's headers. */ - if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM) < 0)) + if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM + junk_size) < 0)) return false; /* Finalize checksum calculation for the inner packet, if required. */ @@ -253,9 +238,11 @@ static bool encrypt_packet(u32 message_type, struct sk_buff *skb, struct noise_k header->counter = cpu_to_le64(PACKET_CB(skb)->nonce); pskb_put(skb, trailer, trailer_len); + get_random_bytes(skb_push(skb, junk_size), junk_size); + /* Now we can encrypt the scattergather segments */ sg_init_table(sg, num_frags); - if (skb_to_sgvec(skb, sg, sizeof(struct message_data), + if (skb_to_sgvec(skb, sg, sizeof(struct message_data) + junk_size, noise_encrypted_len(plaintext_len)) <= 0) return false; return chacha20poly1305_encrypt_sg_inplace(sg, plaintext_len, NULL, 0, @@ -349,7 +336,7 @@ void wg_packet_encrypt_worker(struct work_struct *work) if (likely(encrypt_packet( mh_genheader(&wg->headers[MSGIDX_TRANSPORT]), - wg->advanced_security_config.transport_packet_magic_header : MESSAGE_DATA, + wg->junk_size[MSGIDX_TRANSPORT], skb, PACKET_CB(first)->keypair COMPAT_MAYBE_SIMD_CONTEXT(&simd_context)))) { diff --git a/src/socket.c b/src/socket.c index 23fa6d1..e4a769c 100644 --- a/src/socket.c +++ b/src/socket.c @@ -187,38 +187,30 @@ int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds) } int wg_socket_send_buffer_to_peer(struct wg_peer *peer, void *buffer, - size_t len, u8 ds) + size_t len, u8 ds, size_t junk_size) { - struct sk_buff *skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC); + void* junk; + struct sk_buff *skb = alloc_skb(len + junk_size + SKB_HEADER_LEN, GFP_ATOMIC); if (unlikely(!skb)) return -ENOMEM; skb_reserve(skb, SKB_HEADER_LEN); skb_set_inner_network_header(skb, 0); + junk = skb_put(skb, junk_size); + get_random_bytes(junk, junk_size); skb_put_data(skb, buffer, len); return wg_socket_send_skb_to_peer(peer, skb, ds); } -int wg_socket_send_junked_buffer_to_peer(struct wg_peer *peer, void *buffer, - size_t len, u8 ds, u16 junk_size) -{ - int ret; - void *new_buffer = kzalloc(len + junk_size, GFP_KERNEL); - get_random_bytes(new_buffer, junk_size); - memcpy(new_buffer + junk_size, buffer, len); - ret = wg_socket_send_buffer_to_peer(peer, new_buffer, len + junk_size, ds); - kfree(new_buffer); - return ret; -} - int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg, struct sk_buff *in_skb, void *buffer, - size_t len) + size_t len, size_t junk_size) { int ret = 0; struct sk_buff *skb; struct endpoint endpoint; + void* junk; if (unlikely(!in_skb)) return -EINVAL; @@ -226,11 +218,13 @@ int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg, if (unlikely(ret < 0)) return ret; - skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC); + skb = alloc_skb(len + junk_size + SKB_HEADER_LEN, GFP_ATOMIC); if (unlikely(!skb)) return -ENOMEM; skb_reserve(skb, SKB_HEADER_LEN); skb_set_inner_network_header(skb, 0); + junk = skb_put(skb, junk_size); + get_random_bytes(junk, junk_size); skb_put_data(skb, buffer, len); if (endpoint.addr.sa_family == AF_INET) diff --git a/src/socket.h b/src/socket.h index e4e3f96..82eb53e 100644 --- a/src/socket.h +++ b/src/socket.h @@ -15,14 +15,12 @@ int wg_socket_init(struct wg_device *wg, u16 port); void wg_socket_reinit(struct wg_device *wg, struct sock *new4, struct sock *new6); int wg_socket_send_buffer_to_peer(struct wg_peer *peer, void *data, - size_t len, u8 ds); -int wg_socket_send_junked_buffer_to_peer(struct wg_peer *peer, void *data, - size_t len, u8 ds, u16 junk_size); + size_t len, u8 ds, size_t junk_size); int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds); int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg, struct sk_buff *in_skb, - void *out_buffer, size_t len); + void *out_buffer, size_t len, size_t junk_size); int wg_socket_endpoint_from_skb(struct endpoint *endpoint, const struct sk_buff *skb);