chore: sync with mainstream wireguard

* apply changes from recent kernels
* extend the compatibility layer
Author: Yaroslav Gurov
Date: 2025-09-24 00:16:24 +02:00
Parent: 3d1147e1fb
Commit: 9eb888d250
23 changed files with 269 additions and 101 deletions

File: allowedips.c

@@ -6,6 +6,8 @@
#include "allowedips.h"
#include "peer.h"
enum { MAX_ALLOWEDIPS_DEPTH = 129 };
static struct kmem_cache *node_cache;
static void swap_endian(u8 *dst, const u8 *src, u8 bits)
@@ -13,8 +15,8 @@ static void swap_endian(u8 *dst, const u8 *src, u8 bits)
if (bits == 32) {
*(u32 *)dst = be32_to_cpu(*(const __be32 *)src);
} else if (bits == 128) {
((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]);
((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]);
((u64 *)dst)[0] = get_unaligned_be64(src);
((u64 *)dst)[1] = get_unaligned_be64(src + 8);
}
}
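The old code dereferenced src as a __be64 pointer, which can trap on strict-alignment architectures when the buffer is only byte-aligned; get_unaligned_be64() performs the same big-endian load without the alignment assumption. A minimal sketch of what it guarantees, with the helper name below being illustrative only:

/* Illustrative equivalent of get_unaligned_be64(): assemble the value
 * byte by byte so no aligned 64-bit load is ever issued. */
static inline u64 illustrative_load_be64(const u8 *src)
{
	u64 v = 0;
	int i;

	for (i = 0; i < 8; ++i)
		v = (v << 8) | src[i];
	return v;
}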
@@ -40,7 +42,8 @@ static void push_rcu(struct allowedips_node **stack,
struct allowedips_node __rcu *p, unsigned int *len)
{
if (rcu_access_pointer(p)) {
WARN_ON(IS_ENABLED(DEBUG) && *len >= 128);
if (WARN_ON(IS_ENABLED(DEBUG) && *len >= MAX_ALLOWEDIPS_DEPTH))
return;
stack[(*len)++] = rcu_dereference_raw(p);
}
}
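The new constant reflects that a walk of the trie can push one node per bit of a 128-bit IPv6 key plus the root entry, so the previous 128-entry stacks were one slot short. A hypothetical compile-time check expressing the invariant:

/* Illustrative only: root node plus one node per bit of the widest key. */
static_assert(MAX_ALLOWEDIPS_DEPTH == 1 + 128);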
@@ -52,7 +55,7 @@ static void node_free_rcu(struct rcu_head *rcu)
static void root_free_rcu(struct rcu_head *rcu)
{
struct allowedips_node *node, *stack[128] = {
struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = {
container_of(rcu, struct allowedips_node, rcu) };
unsigned int len = 1;
@@ -65,7 +68,7 @@ static void root_free_rcu(struct rcu_head *rcu)
static void root_remove_peer_lists(struct allowedips_node *root)
{
struct allowedips_node *node, *stack[128] = { root };
struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = { root };
unsigned int len = 1;
while (len > 0 && (node = stack[--len])) {

File: Kbuild

@@ -45,7 +45,7 @@ ccflags-y += -I$(kbuild-dir)/compat/udp_tunnel/include
amneziawg-y += compat/udp_tunnel/udp_tunnel.o
endif
ifeq ($(shell grep -s -F "int crypto_memneq" "$(srctree)/include/crypto/algapi.h"),)
ifeq ($(shell grep -s -F "int crypto_memneq" "$(srctree)/include/crypto/algapi.h")$(shell grep -s -F "int crypto_memneq" "$(srctree)/include/crypto/utils.h"),)
ccflags-y += -include $(kbuild-dir)/compat/memneq/include.h
amneziawg-y += compat/memneq/memneq.o
endif
@@ -107,3 +107,31 @@ endif
ifneq ($(shell grep -s -F "\#define LINUX_PACKAGE_ID \" Debian " "$(CURDIR)/include/generated/package.h"),)
ccflags-y += -DISDEBIAN
endif
ifeq ($(wildcard $(srctree)/include/crypto/blake2s.h),)
ccflags-y += -I$(kbuild-dir)/compat/crypto/blake2s/include
endif
ifeq ($(wildcard $(srctree)/include/crypto/chacha20poly1305.h),)
ccflags-y += -I$(kbuild-dir)/compat/crypto/chacha20poly1305/include
endif
ifeq ($(wildcard $(srctree)/include/crypto/utils.h),)
ccflags-y += -I$(kbuild-dir)/compat/crypto/utils/include
endif
ifeq ($(wildcard $(srctree)/include/crypto/curve25519.h),)
ccflags-y += -I$(kbuild-dir)/compat/crypto/curve25519/include
endif
ifeq ($(wildcard $(srctree)/include/linux/kstrtox.h),)
ccflags-y += -I$(kbuild-dir)/compat/kstrtox/include
endif
ifeq ($(wildcard $(srctree)/include/net/gso.h),)
ccflags-y += -I$(kbuild-dir)/compat/gso/include
endif
ifeq ($(wildcard $(srctree)/include/linux/sprintf.h),)
ccflags-y += -I$(kbuild-dir)/compat/sprintf/include
endif
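crypto_memneq() moved from <crypto/algapi.h> to the new <crypto/utils.h> around Linux 6.4, hence the doubled grep above; the wildcard rules then graft shim headers into the include path wherever the running tree lacks one. A consumer sketch under those assumptions (the function below is hypothetical):

/* Resolves to the in-tree shim on old kernels, the real header on new ones. */
#include <crypto/utils.h>

static bool cookies_differ(const u8 a[16], const u8 b[16])
{
	/* crypto_memneq() compares in constant time, unlike memcmp(). */
	return crypto_memneq(a, b, 16);
}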

File: compat/compat.h

@@ -888,13 +888,14 @@ static inline void skb_mark_not_on_list(struct sk_buff *skb)
#endif
#endif
#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 200) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 19, 249)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 15, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 14, 285)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 10, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 9, 320))
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 200) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 19, 249)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 15, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 14, 285)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 10, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 9, 320))) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 10, 0)
#define COMPAT_INIT_CRYPTO
#define blake2s_init zinc_blake2s_init
#define blake2s_init_key zinc_blake2s_init_key
#define blake2s_update zinc_blake2s_update
#define blake2s_final zinc_blake2s_final
#endif
#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 5, 0)
#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 5, 0) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 10, 0)
#define blake2s_hmac zinc_blake2s_hmac
#define chacha20 zinc_chacha20
#define hchacha20 zinc_hchacha20
@@ -939,6 +940,13 @@ static inline void skb_mark_not_on_list(struct sk_buff *skb)
#define chacha20_neon zinc_chacha20_neon
#endif
#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 10, 0)
#define COMPAT_CRYPTO_IS_ZINC
#define COMPAT_MAYBE_SIMD_CONTEXT(ctx) , ctx
#else
#define COMPAT_MAYBE_SIMD_CONTEXT(ctx)
#endif
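COMPAT_MAYBE_SIMD_CONTEXT() lets a single call site serve both the bundled zinc API (pre-5.10, which threads an explicit simd_context_t through every call) and the upstream crypto library (5.10+, which manages SIMD state internally). For instance, the decrypt call rewritten later in this commit, with declarations elided:

/* One spelling for both APIs. */
chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
				    PACKET_CB(skb)->nonce,
				    keypair->receiving.key
				    COMPAT_MAYBE_SIMD_CONTEXT(simd_context));
/* Pre-5.10 this expands with a trailing ", simd_context" argument;
 * on 5.10+ the macro expands to nothing. */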
#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 19, 0) && !defined(ISRHEL7)
#include <linux/skbuff.h>
static inline int skb_ensure_writable(struct sk_buff *skb, int write_len)
@@ -1117,7 +1125,7 @@ static const struct header_ops ip_tunnel_header_ops = { .parse_protocol = ip_tun
#endif
#endif
#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 16, 0)
#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 15, 30)
#include <net/dst_cache.h>
struct dst_cache_pcpu {
unsigned long refresh_ts;
@@ -1192,4 +1200,77 @@ static inline void dst_cache_reset_now(struct dst_cache *dst_cache)
#define from_timer(var, callback_timer, timer_fieldname) container_of((struct timer_list *)callback_timer, typeof(*var), timer_fieldname)
#endif
#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 11, 0)
#include <net/flow.h>
#define flowi4_to_flowi_common(fl4) flowi4_to_flowi(fl4)
#define flowi6_to_flowi_common(fl6) flowi6_to_flowi(fl6)
#endif
#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 6, 0)
#define genl_info_dump(cb) genl_dumpit_info(cb)
#endif
#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 1, 84) && \
!(LINUX_VERSION_CODE >= KERNEL_VERSION(4, 19, 312) && LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0)) && \
!(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 274) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 5, 0)) && \
!(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 10, 215) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 11, 0)) && \
!(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 15, 154) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 16, 0))
#define timer_delete_sync(timer) del_timer_sync(timer)
#endif
#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 2, 0)
#include <linux/random.h>
static inline u32 get_random_u32_below(u32 ceil)
{
return get_random_u32() % ceil;
}
static inline u32 get_random_u32_inclusive(u32 floor, u32 ceil)
{
return floor + get_random_u32_below(ceil - floor + 1);
}
#endif
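get_random_u32_below() and get_random_u32_inclusive() only appeared in Linux 6.2, so these fallbacks keep callers such as the timer jitter in timers.c compiling on older kernels. Note that the fallback's modulo carries a slight bias, tolerable for timer jitter but not for anything cryptographic. Usage as in the timers.c hunk later in this commit:

mod_peer_timer(peer, &peer->timer_retransmit_handshake,
	       jiffies + REKEY_TIMEOUT * HZ +
	       get_random_u32_below(REKEY_TIMEOUT_JITTER_MAX_JIFFIES));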
#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 1, 0)
#define COMPAT_NETIF_HAS_WEIGHT
#endif
#if LINUX_VERSION_CODE >= KERNEL_VERSION(6, 1, 0)
#define COMPAT_GENL_HAS_RESV_START_OP
#endif
#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 2, 0) && \
!(LINUX_VERSION_CODE >= KERNEL_VERSION(4, 19, 296) && LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0)) && \
!(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 229) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 5, 0)) && \
!(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 10, 163) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 11, 0)) && \
!(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 15, 86) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 16, 0))
#undef DEV_STATS_INC
#define DEV_STATS_INC(DEV, FIELD) ++DEV->stats.FIELD
#undef DEV_STATS_ADD
#define DEV_STATS_ADD(DEV, FIELD, VAL) DEV->stats.FIELD += VAL
#endif
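On kernels that ship the upstream macros (6.2 plus the stable backports excluded above), dev->stats fields are atomic_long_t and DEV_STATS_INC()/DEV_STATS_ADD() come from <linux/netdevice.h>; this fallback keeps the same call shape over the older plain counters. The call sites converted elsewhere in this commit then build everywhere:

DEV_STATS_INC(dev, tx_dropped);    /* was: ++dev->stats.tx_dropped */
DEV_STATS_ADD(dev, tx_dropped,     /* was: dev->stats.tx_dropped += qlen */
	      peer->staged_packet_queue.qlen);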
#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 17, 0)
#define COMPAT_SKB_HAS_SKB_START
#endif
#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 11, 0)
#define dev_get_tstats64 ip_tunnel_get_stats64
#endif
#if LINUX_VERSION_CODE >= KERNEL_VERSION(6, 12, 0)
#define COMPAT_NETDEV_HAS_LLTX_PARAM
#endif
#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 6, 0)
static inline void dev_sw_netstats_rx_add(struct net_device *dev, unsigned int len) {
struct pcpu_sw_netstats *tstats = get_cpu_ptr(dev->tstats);
u64_stats_update_begin(&tstats->syncp);
++tstats->rx_packets;
tstats->rx_bytes += len;
u64_stats_update_end(&tstats->syncp);
put_cpu_ptr(tstats);
}
#endif
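With this fallback in place, update_rx_stats() in receive.c collapses to the two lines below (as the receive.c hunk later in this commit shows) on both old and new kernels:

static void update_rx_stats(struct wg_peer *peer, size_t len)
{
	dev_sw_netstats_rx_add(peer->device->dev, len);
	peer->rx_bytes += len;
}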
#endif /* _WG_COMPAT_H */

File: compat/crypto/blake2s/include/crypto/blake2s.h

@@ -0,0 +1 @@
#include <zinc/blake2s.h>

File: compat/crypto/chacha20poly1305/include/crypto/chacha20poly1305.h

@@ -0,0 +1 @@
#include <zinc/chacha20poly1305.h>

File: compat/crypto/curve25519/include/crypto/curve25519.h

@@ -0,0 +1 @@
#include <zinc/curve25519.h>

File: compat/crypto/utils/include/crypto/utils.h

@@ -0,0 +1 @@
#include <crypto/algapi.h>

File: compat/gso/include/net/gso.h

@@ -0,0 +1,6 @@
#ifndef _AWG_COMPAT_NET_GSO
#define _AWG_COMPAT_NET_GSO
#include <linux/netdevice.h>
#endif
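The GSO helpers (skb_gso_segment() and friends) moved out of <linux/netdevice.h> into the new <net/gso.h> around Linux 6.5; this shim lets device.c include <net/gso.h> unconditionally. A sketch under that assumption (the wrapper function is hypothetical):

#include <net/gso.h>	/* the shim on older kernels, the real header on 6.5+ */

static struct sk_buff *segment(struct sk_buff *skb, netdev_features_t features)
{
	/* Declared by whichever header the include path resolved to. */
	return skb_gso_segment(skb, features);
}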

File: compat/kstrtox/include/linux/kstrtox.h

@@ -0,0 +1 @@
#include <linux/kernel.h>

File: compat/sprintf/include/linux/sprintf.h

@@ -0,0 +1 @@
#include <linux/kernel.h>

File: cookie.c

@@ -10,11 +10,11 @@
#include "ratelimiter.h"
#include "timers.h"
#include <zinc/blake2s.h>
#include <zinc/chacha20poly1305.h>
#include <crypto/blake2s.h>
#include <crypto/chacha20poly1305.h>
#include <crypto/utils.h>
#include <net/ipv6.h>
#include <crypto/algapi.h>
void wg_cookie_checker_init(struct cookie_checker *checker,
struct wg_device *wg)

File: device.c

@@ -20,6 +20,7 @@
#include <linux/icmp.h>
#include <linux/suspend.h>
#include <net/dst_metadata.h>
#include <net/gso.h>
#include <net/icmp.h>
#include <net/rtnetlink.h>
#include <net/ip_tunnels.h>
@@ -201,7 +202,7 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev)
*/
while (skb_queue_len(&peer->staged_packet_queue) > MAX_STAGED_PACKETS) {
dev_kfree_skb(__skb_dequeue(&peer->staged_packet_queue));
++dev->stats.tx_dropped;
DEV_STATS_INC(dev, tx_dropped);
}
skb_queue_splice_tail(&packets, &peer->staged_packet_queue);
spin_unlock_bh(&peer->staged_packet_queue.lock);
@@ -219,7 +220,7 @@ err_icmp:
else if (skb->protocol == htons(ETH_P_IPV6))
icmpv6_ndo_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, 0);
err:
++dev->stats.tx_errors;
DEV_STATS_INC(dev, tx_errors);
kfree_skb(skb);
return ret;
}
@@ -228,7 +229,7 @@ static const struct net_device_ops netdev_ops = {
.ndo_open = wg_open,
.ndo_stop = wg_stop,
.ndo_start_xmit = wg_xmit,
.ndo_get_stats64 = ip_tunnel_get_stats64
.ndo_get_stats64 = dev_get_tstats64
};
static void wg_destruct(struct net_device *dev)
@@ -286,7 +287,11 @@ static void wg_setup(struct net_device *dev)
#else
dev->tx_queue_len = 0;
#endif
#ifdef COMPAT_NETDEV_HAS_LLTX_PARAM
dev->lltx = true;
#else
dev->features |= NETIF_F_LLTX;
#endif
dev->features |= WG_NETDEV_FEATURES;
dev->hw_features |= WG_NETDEV_FEATURES;
dev->hw_enc_features |= WG_NETDEV_FEATURES;

File: main.c

@@ -11,21 +11,22 @@
#include "ratelimiter.h"
#include "netlink.h"
#include "uapi/wireguard.h"
#include "crypto/zinc.h"
#include <linux/init.h>
#include <linux/module.h>
#include <linux/genetlink.h>
#include <net/genetlink.h>
#include <net/rtnetlink.h>
static int __init wg_mod_init(void)
{
int ret;
#ifdef COMPAT_INIT_CRYPTO
if ((ret = chacha20_mod_init()) || (ret = poly1305_mod_init()) ||
(ret = chacha20poly1305_mod_init()) || (ret = blake2s_mod_init()) ||
(ret = curve25519_mod_init()))
return ret;
#endif
ret = wg_allowedips_slab_init();
if (ret < 0)

File: messages.h

@@ -6,9 +6,9 @@
#ifndef _WG_MESSAGES_H
#define _WG_MESSAGES_H
#include <zinc/curve25519.h>
#include <zinc/chacha20poly1305.h>
#include <zinc/blake2s.h>
#include <crypto/curve25519.h>
#include <crypto/chacha20poly1305.h>
#include <crypto/blake2s.h>
#include <linux/kernel.h>
#include <linux/param.h>

File: netlink.c

@@ -13,7 +13,7 @@
#include <linux/if.h>
#include <net/genetlink.h>
#include <net/sock.h>
#include <crypto/algapi.h>
#include <crypto/utils.h>
#include <linux/random.h>
#include <linux/bitops.h>
@@ -356,8 +356,8 @@ get_peer(struct wg_peer *peer, struct sk_buff *skb, struct dump_ctx *ctx)
if (!allowedips_node)
goto no_allowedips;
if (!ctx->allowedips_seq)
ctx->allowedips_seq = peer->device->peer_allowedips.seq;
else if (ctx->allowedips_seq != peer->device->peer_allowedips.seq)
ctx->allowedips_seq = ctx->wg->peer_allowedips.seq;
else if (ctx->allowedips_seq != ctx->wg->peer_allowedips.seq)
goto no_allowedips;
allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS);
@@ -392,7 +392,7 @@ static int wg_get_device_start(struct netlink_callback *cb)
{
struct wg_device *wg;
wg = lookup_interface(genl_dumpit_info(cb)->attrs, cb->skb);
wg = lookup_interface(genl_info_dump(cb)->attrs, cb->skb);
if (IS_ERR(wg))
return PTR_ERR(wg);
DUMP_CTX(cb)->wg = wg;
@@ -892,7 +892,6 @@ struct genl_ops genl_ops[] = {
#ifdef COMPAT_CANNOT_INDIVIDUAL_NETLINK_OPS_POLICY
.policy = device_policy,
#endif
// Dummy comment to reduce fuzziness of patch file
.flags = GENL_UNS_ADMIN_PERM
}
};
@@ -910,6 +909,9 @@ __ro_after_init = {
.n_ops = ARRAY_SIZE(genl_ops),
#else
= {
#endif
#ifdef COMPAT_GENL_HAS_RESV_START_OP
.resv_start_op = WG_CMD_SET_DEVICE + 1,
#endif
.name = WG_GENL_NAME,
.version = WG_GENL_VERSION,

File: noise.c

@@ -17,7 +17,7 @@
#include <linux/bitmap.h>
#include <linux/scatterlist.h>
#include <linux/highmem.h>
#include <crypto/algapi.h>
#include <crypto/utils.h>
/* This implements Noise_IKpsk2:
*
@@ -304,6 +304,41 @@ void wg_noise_set_static_identity_private_key(
static_identity->static_public, private_key);
}
static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen)
{
struct blake2s_state state;
u8 x_key[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 };
u8 i_hash[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32));
int i;
if (keylen > BLAKE2S_BLOCK_SIZE) {
blake2s_init(&state, BLAKE2S_HASH_SIZE);
blake2s_update(&state, key, keylen);
blake2s_final(&state, x_key);
} else
memcpy(x_key, key, keylen);
for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
x_key[i] ^= 0x36;
blake2s_init(&state, BLAKE2S_HASH_SIZE);
blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
blake2s_update(&state, in, inlen);
blake2s_final(&state, i_hash);
for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
x_key[i] ^= 0x5c ^ 0x36;
blake2s_init(&state, BLAKE2S_HASH_SIZE);
blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE);
blake2s_final(&state, i_hash);
memcpy(out, i_hash, BLAKE2S_HASH_SIZE);
memzero_explicit(x_key, BLAKE2S_BLOCK_SIZE);
memzero_explicit(i_hash, BLAKE2S_HASH_SIZE);
}
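blake2s_hmac() was a WireGuard-only helper that newer kernels' BLAKE2s library no longer exports, which is why noise.c now carries its own HMAC. One subtlety in the open-coded version above: after the inner hash, x_key is still XORed with the ipad byte 0x36, so XORing each byte with 0x5c ^ 0x36 flips it to the opad form in place. The equivalent two-step spelling:

for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i) {
	x_key[i] ^= 0x36;	/* undo ipad */
	x_key[i] ^= 0x5c;	/* apply opad */
}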
/* This is Hugo Krawczyk's HKDF:
* - https://eprint.iacr.org/2010/264.pdf
* - https://tools.ietf.org/html/rfc5869
@@ -324,16 +359,14 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
((third_len || third_dst) && (!second_len || !second_dst))));
/* Extract entropy from data into secret */
blake2s_hmac(secret, data, chaining_key, BLAKE2S_HASH_SIZE, data_len,
NOISE_HASH_LEN);
hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
if (!first_dst || !first_len)
goto out;
/* Expand first key: key = secret, data = 0x1 */
output[0] = 1;
blake2s_hmac(output, output, secret, BLAKE2S_HASH_SIZE, 1,
BLAKE2S_HASH_SIZE);
hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
memcpy(first_dst, output, first_len);
if (!second_dst || !second_len)
@@ -341,8 +374,7 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
/* Expand second key: key = secret, data = first-key || 0x2 */
output[BLAKE2S_HASH_SIZE] = 2;
blake2s_hmac(output, output, secret, BLAKE2S_HASH_SIZE,
BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
memcpy(second_dst, output, second_len);
if (!third_dst || !third_len)
@@ -350,8 +382,7 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
/* Expand third key: key = secret, data = second-key || 0x3 */
output[BLAKE2S_HASH_SIZE] = 3;
blake2s_hmac(output, output, secret, BLAKE2S_HASH_SIZE,
BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
memcpy(third_dst, output, third_len);
out:

File: peer.c

@@ -55,8 +55,11 @@ struct wg_peer *wg_peer_create(struct wg_device *wg,
skb_queue_head_init(&peer->staged_packet_queue);
wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
set_bit(NAPI_STATE_NO_BUSY_POLL, &peer->napi.state);
netif_napi_add(wg->dev, &peer->napi, wg_packet_rx_poll,
NAPI_POLL_WEIGHT);
#ifdef COMPAT_NETIF_HAS_WEIGHT
netif_napi_add(wg->dev, &peer->napi, wg_packet_rx_poll, NAPI_POLL_WEIGHT);
#else
netif_napi_add(wg->dev, &peer->napi, wg_packet_rx_poll);
#endif
napi_enable(&peer->napi);
list_add_tail(&peer->peer_list, &wg->peer_list);
INIT_LIST_HEAD(&peer->allowedips_list);

File: queueing.c

@@ -28,6 +28,7 @@ int wg_packet_queue_init(struct crypt_queue *queue, work_func_t function,
int ret;
memset(queue, 0, sizeof(*queue));
queue->last_cpu = -1;
ret = ptr_ring_init(&queue->ring, len, GFP_KERNEL);
if (ret)
return ret;

File: queueing.h

@@ -75,20 +75,21 @@ static inline bool wg_check_packet_protocol(struct sk_buff *skb)
static inline void wg_reset_packet(struct sk_buff *skb, bool encapsulating)
{
const int pfmemalloc = skb->pfmemalloc;
u32 hash = skb->hash;
u8 l4_hash = skb->l4_hash;
u8 sw_hash = skb->sw_hash;
u32 hash = skb->hash;
skb_scrub_packet(skb, true);
#ifdef COMPAT_SKB_HAS_SKB_START
memset(&skb->headers_start, 0,
offsetof(struct sk_buff, headers_end) -
offsetof(struct sk_buff, headers_start));
skb->pfmemalloc = pfmemalloc;
#else
memset(&skb->headers, 0, sizeof(skb->headers));
#endif
if (encapsulating) {
skb->hash = hash;
skb->l4_hash = l4_hash;
skb->sw_hash = sw_hash;
skb->hash = hash;
}
skb->queue_mapping = 0;
skb->nohdr = 0;
@@ -111,7 +112,7 @@ static inline int wg_cpumask_choose_online(int *stored_cpu, unsigned int id)
{
unsigned int cpu = *stored_cpu, cpu_index, i;
if (unlikely(cpu == nr_cpumask_bits ||
if (unlikely(cpu >= nr_cpu_ids ||
!cpumask_test_cpu(cpu, cpu_online_mask))) {
cpu_index = id % cpumask_weight(cpu_online_mask);
cpu = cpumask_first(cpu_online_mask);
@@ -122,20 +123,17 @@ static inline int wg_cpumask_choose_online(int *stored_cpu, unsigned int id)
return cpu;
}
/* This function is racy, in the sense that next is unlocked, so it could return
* the same CPU twice. A race-free version of this would be to instead store an
* atomic sequence number, do an increment-and-return, and then iterate through
* every possible CPU until we get to that index -- choose_cpu. However that's
* a bit slower, and it doesn't seem like this potential race actually
* introduces any performance loss, so we live with it.
/* This function is racy, in the sense that it's called while last_cpu is
* unlocked, so it could return the same CPU twice. Adding locking or using
* atomic sequence numbers is slower though, and the consequences of racing are
* harmless, so live with it.
*/
static inline int wg_cpumask_next_online(int *next)
static inline int wg_cpumask_next_online(int *last_cpu)
{
int cpu = *next;
while (unlikely(!cpumask_test_cpu(cpu, cpu_online_mask)))
cpu = cpumask_next(cpu, cpu_online_mask) % nr_cpumask_bits;
*next = cpumask_next(cpu, cpu_online_mask) % nr_cpumask_bits;
int cpu = cpumask_next(READ_ONCE(*last_cpu), cpu_online_mask);
if (cpu >= nr_cpu_ids)
cpu = cpumask_first(cpu_online_mask);
WRITE_ONCE(*last_cpu, cpu);
return cpu;
}
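The rewrite drops the racy read-modify-write loop in favor of a single cpumask_next()/cpumask_first() wrap with READ_ONCE/WRITE_ONCE pairing. A round-robin sketch: with CPUs {0, 2, 5} online and last_cpu initialized to -1 (as wg_packet_queue_init() now does), successive calls return 0, 2, 5, 0, and so on; concurrent callers may occasionally pick the same CPU, which the comment above deems harmless.

int last_cpu = -1;	/* matches queue->last_cpu = -1 in queueing.c */
int a = wg_cpumask_next_online(&last_cpu);	/* 0 */
int b = wg_cpumask_next_online(&last_cpu);	/* 2 */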
@@ -164,7 +162,7 @@ static inline void wg_prev_queue_drop_peeked(struct prev_queue *queue)
static inline int wg_queue_enqueue_per_device_and_peer(
struct crypt_queue *device_queue, struct prev_queue *peer_queue,
struct sk_buff *skb, struct workqueue_struct *wq, int *next_cpu)
struct sk_buff *skb, struct workqueue_struct *wq)
{
int cpu;
@@ -178,7 +176,7 @@ static inline int wg_queue_enqueue_per_device_and_peer(
/* Then we queue it up in the device queue, which consumes the
* packet as soon as it can.
*/
cpu = wg_cpumask_next_online(next_cpu);
cpu = wg_cpumask_next_online(&device_queue->last_cpu);
if (unlikely(ptr_ring_produce_bh(&device_queue->ring, skb)))
return -EPIPE;
queue_work_on(cpu, wq, &per_cpu_ptr(device_queue->worker, cpu)->work);

File: receive.c

@@ -20,15 +20,8 @@
/* Must be called with bh disabled. */
static void update_rx_stats(struct wg_peer *peer, size_t len)
{
struct pcpu_sw_netstats *tstats =
get_cpu_ptr(peer->device->dev->tstats);
u64_stats_update_begin(&tstats->syncp);
++tstats->rx_packets;
tstats->rx_bytes += len;
dev_sw_netstats_rx_add(peer->device->dev, len);
peer->rx_bytes += len;
u64_stats_update_end(&tstats->syncp);
put_cpu_ptr(tstats);
}
static size_t validate_header_len(struct sk_buff *skb, struct wg_device *wg)
@@ -293,7 +286,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair,
if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
READ_ONCE(keypair->receiving_counter.counter) >= REJECT_AFTER_MESSAGES)) {
WRITE_ONCE(keypair->receiving.is_valid, false);
return false;
}
@@ -305,7 +298,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair,
* call skb_cow_data, so that there's no chance that data is removed
* from the skb, so that later we can extract the original endpoint.
*/
offset = skb->data - skb_network_header(skb);
offset = -skb_network_offset(skb);
skb_push(skb, offset);
num_frags = skb_cow_data(skb, 0, &trailer);
offset += sizeof(struct message_data);
@@ -318,9 +311,9 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair,
return false;
if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
PACKET_CB(skb)->nonce,
keypair->receiving.key,
simd_context))
PACKET_CB(skb)->nonce,
keypair->receiving.key
COMPAT_MAYBE_SIMD_CONTEXT(simd_context)))
return false;
/* Another ugly situation of pushing and pulling the header so as to
@@ -361,7 +354,7 @@ static bool counter_validate(struct noise_replay_counter *counter, u64 their_cou
for (i = 1; i <= top; ++i)
counter->backtrack[(i + index_current) &
((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
counter->counter = their_counter;
WRITE_ONCE(counter->counter, their_counter);
}
index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
@@ -461,20 +454,20 @@ dishonest_packet_peer:
net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n",
dev->name, skb, peer->internal_id,
&peer->endpoint.addr);
++dev->stats.rx_errors;
++dev->stats.rx_frame_errors;
DEV_STATS_INC(dev, rx_errors);
DEV_STATS_INC(dev, rx_frame_errors);
goto packet_processed;
dishonest_packet_type:
net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n",
dev->name, peer->internal_id, &peer->endpoint.addr);
++dev->stats.rx_errors;
++dev->stats.rx_frame_errors;
DEV_STATS_INC(dev, rx_errors);
DEV_STATS_INC(dev, rx_frame_errors);
goto packet_processed;
dishonest_packet_size:
net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n",
dev->name, peer->internal_id, &peer->endpoint.addr);
++dev->stats.rx_errors;
++dev->stats.rx_length_errors;
DEV_STATS_INC(dev, rx_errors);
DEV_STATS_INC(dev, rx_length_errors);
goto packet_processed;
packet_processed:
dev_kfree_skb(skb);
@@ -508,7 +501,7 @@ int wg_packet_rx_poll(struct napi_struct *napi, int budget)
net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
peer->device->dev->name,
PACKET_CB(skb)->nonce,
keypair->receiving_counter.counter);
READ_ONCE(keypair->receiving_counter.counter));
goto next;
}
@@ -573,7 +566,7 @@ static void wg_packet_consume_data(struct wg_device *wg, struct sk_buff *skb)
goto err;
ret = wg_queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb,
wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu);
wg->packet_crypt_wq);
if (unlikely(ret == -EPIPE))
wg_queue_enqueue_per_peer_rx(skb, PACKET_STATE_DEAD);
if (likely(!ret || ret == -EPIPE)) {

File: send.c

@@ -3,6 +3,7 @@
* Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
#include "compat/compat.h"
#include "queueing.h"
#include "timers.h"
#include "device.h"
@@ -219,8 +220,8 @@ static unsigned int calculate_skb_padding(struct sk_buff *skb)
return padded_size - last_unit;
}
static bool encrypt_packet(u32 message_type, struct sk_buff *skb, struct noise_keypair *keypair,
simd_context_t *simd_context)
static bool encrypt_packet(u32 message_type, struct sk_buff *skb, struct noise_keypair *keypair
COMPAT_MAYBE_SIMD_CONTEXT(simd_context_t *simd_context))
{
unsigned int padding_len, plaintext_len, trailer_len;
struct scatterlist sg[MAX_SKB_FRAGS + 8];
@@ -276,15 +277,15 @@ static bool encrypt_packet(u32 message_type, struct sk_buff *skb, struct noise_k
return false;
return chacha20poly1305_encrypt_sg_inplace(sg, plaintext_len, NULL, 0,
PACKET_CB(skb)->nonce,
keypair->sending.key,
simd_context);
keypair->sending.key
COMPAT_MAYBE_SIMD_CONTEXT(simd_context));
}
void wg_packet_send_keepalive(struct wg_peer *peer)
{
struct sk_buff *skb;
if (skb_queue_empty(&peer->staged_packet_queue)) {
if (skb_queue_empty_lockless(&peer->staged_packet_queue)) {
skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH,
GFP_ATOMIC);
if (unlikely(!skb))
@@ -352,9 +353,11 @@ void wg_packet_encrypt_worker(struct work_struct *work)
work)->ptr;
struct sk_buff *first, *skb, *next;
struct wg_device *wg;
simd_context_t simd_context;
#ifdef COMPAT_CRYPTO_IS_ZINC
simd_context_t simd_context;
simd_get(&simd_context);
#endif
while ((first = ptr_ring_consume_bh(&queue->ring)) != NULL) {
enum packet_state state = PACKET_STATE_CRYPTED;
@@ -364,8 +367,8 @@ void wg_packet_encrypt_worker(struct work_struct *work)
if (likely(encrypt_packet(PACKET_PEER(first)->advanced_security ?
wg->advanced_security_config.transport_packet_magic_header : MESSAGE_DATA,
skb,
PACKET_CB(first)->keypair,
&simd_context))) {
PACKET_CB(first)->keypair
COMPAT_MAYBE_SIMD_CONTEXT(&simd_context)))) {
wg_reset_packet(skb, true);
} else {
state = PACKET_STATE_DEAD;
@@ -374,9 +377,15 @@ void wg_packet_encrypt_worker(struct work_struct *work)
}
wg_queue_enqueue_per_peer_tx(first, state);
#ifdef COMPAT_CRYPTO_IS_ZINC
simd_relax(&simd_context);
#endif
if (need_resched())
cond_resched();
}
#ifdef COMPAT_CRYPTO_IS_ZINC
simd_put(&simd_context);
#endif
}
static void wg_packet_create_data(struct wg_peer *peer, struct sk_buff *first)
@@ -389,7 +398,7 @@ static void wg_packet_create_data(struct wg_peer *peer, struct sk_buff *first)
goto err;
ret = wg_queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first,
wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu);
wg->packet_crypt_wq);
if (unlikely(ret == -EPIPE))
wg_queue_enqueue_per_peer_tx(first, PACKET_STATE_DEAD);
err:
@@ -404,7 +413,8 @@ err:
void wg_packet_purge_staged_packets(struct wg_peer *peer)
{
spin_lock_bh(&peer->staged_packet_queue.lock);
peer->device->dev->stats.tx_dropped += peer->staged_packet_queue.qlen;
DEV_STATS_ADD(peer->device->dev, tx_dropped,
peer->staged_packet_queue.qlen);
__skb_queue_purge(&peer->staged_packet_queue);
spin_unlock_bh(&peer->staged_packet_queue.lock);
}

View File

@@ -49,7 +49,7 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
rt = dst_cache_get_ip4(cache, &fl.saddr);
if (!rt) {
security_sk_classify_flow(sock, flowi4_to_flowi(&fl));
security_sk_classify_flow(sock, flowi4_to_flowi_common(&fl));
if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0,
fl.saddr, RT_SCOPE_HOST))) {
endpoint->src4.s_addr = 0;
@@ -129,7 +129,7 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
dst = dst_cache_get_ip6(cache, &fl.saddr);
if (!dst) {
security_sk_classify_flow(sock, flowi6_to_flowi(&fl));
security_sk_classify_flow(sock, flowi6_to_flowi_common(&fl));
if (unlikely(!ipv6_addr_any(&fl.saddr) &&
!ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) {
endpoint->src6 = fl.saddr = in6addr_any;

File: timers.c

@@ -46,7 +46,7 @@ static void wg_expired_retransmit_handshake(struct timer_list *timer)
if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) {
pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n",
peer->device->dev->name, peer->internal_id,
&peer->endpoint.addr, MAX_TIMER_HANDSHAKES + 2);
&peer->endpoint.addr, (int)MAX_TIMER_HANDSHAKES + 2);
del_timer(&peer->timer_send_keepalive);
/* We drop all packets without a keypair and don't try again,
@@ -64,7 +64,7 @@ static void wg_expired_retransmit_handshake(struct timer_list *timer)
++peer->timer_handshake_attempts;
pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d seconds, retrying (try %d)\n",
peer->device->dev->name, peer->internal_id,
&peer->endpoint.addr, REKEY_TIMEOUT,
&peer->endpoint.addr, (int)REKEY_TIMEOUT,
peer->timer_handshake_attempts + 1);
/* We clear the endpoint address src address, in case this is
@@ -94,7 +94,7 @@ static void wg_expired_new_handshake(struct timer_list *timer)
pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n",
peer->device->dev->name, peer->internal_id,
&peer->endpoint.addr, KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
&peer->endpoint.addr, (int)(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT));
/* We clear the endpoint address src address, in case this is the cause
* of trouble.
*/
@@ -126,7 +126,7 @@ static void wg_queued_expired_zero_key_material(struct work_struct *work)
pr_debug("%s: Zeroing out all keys for peer %llu (%pISpfsc), since we haven't received a new one in %d seconds\n",
peer->device->dev->name, peer->internal_id,
&peer->endpoint.addr, REJECT_AFTER_TIME * 3);
&peer->endpoint.addr, (int)REJECT_AFTER_TIME * 3);
wg_noise_handshake_clear(&peer->handshake);
wg_noise_keypairs_clear(&peer->keypairs);
wg_peer_put(peer);
@@ -147,7 +147,7 @@ void wg_timers_data_sent(struct wg_peer *peer)
if (!timer_pending(&peer->timer_new_handshake))
mod_peer_timer(peer, &peer->timer_new_handshake,
jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ +
prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES));
get_random_u32_below(REKEY_TIMEOUT_JITTER_MAX_JIFFIES));
}
/* Should be called after an authenticated data packet is received. */
@@ -183,7 +183,7 @@ void wg_timers_handshake_initiated(struct wg_peer *peer)
{
mod_peer_timer(peer, &peer->timer_retransmit_handshake,
jiffies + REKEY_TIMEOUT * HZ +
prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES));
get_random_u32_below(REKEY_TIMEOUT_JITTER_MAX_JIFFIES));
}
/* Should be called after a handshake response message is received and processed
@@ -234,10 +234,10 @@ void wg_timers_init(struct wg_peer *peer)
void wg_timers_stop(struct wg_peer *peer)
{
del_timer_sync(&peer->timer_retransmit_handshake);
del_timer_sync(&peer->timer_send_keepalive);
del_timer_sync(&peer->timer_new_handshake);
del_timer_sync(&peer->timer_zero_key_material);
del_timer_sync(&peer->timer_persistent_keepalive);
timer_delete_sync(&peer->timer_retransmit_handshake);
timer_delete_sync(&peer->timer_send_keepalive);
timer_delete_sync(&peer->timer_new_handshake);
timer_delete_sync(&peer->timer_zero_key_material);
timer_delete_sync(&peer->timer_persistent_keepalive);
flush_work(&peer->clear_peer_work);
}