diff --git a/src/allowedips.c b/src/allowedips.c index 9a4c8ff..4b85282 100644 --- a/src/allowedips.c +++ b/src/allowedips.c @@ -6,6 +6,8 @@ #include "allowedips.h" #include "peer.h" +enum { MAX_ALLOWEDIPS_DEPTH = 129 }; + static struct kmem_cache *node_cache; static void swap_endian(u8 *dst, const u8 *src, u8 bits) @@ -13,8 +15,8 @@ static void swap_endian(u8 *dst, const u8 *src, u8 bits) if (bits == 32) { *(u32 *)dst = be32_to_cpu(*(const __be32 *)src); } else if (bits == 128) { - ((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]); - ((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]); + ((u64 *)dst)[0] = get_unaligned_be64(src); + ((u64 *)dst)[1] = get_unaligned_be64(src + 8); } } @@ -40,7 +42,8 @@ static void push_rcu(struct allowedips_node **stack, struct allowedips_node __rcu *p, unsigned int *len) { if (rcu_access_pointer(p)) { - WARN_ON(IS_ENABLED(DEBUG) && *len >= 128); + if (WARN_ON(IS_ENABLED(DEBUG) && *len >= MAX_ALLOWEDIPS_DEPTH)) + return; stack[(*len)++] = rcu_dereference_raw(p); } } @@ -52,7 +55,7 @@ static void node_free_rcu(struct rcu_head *rcu) static void root_free_rcu(struct rcu_head *rcu) { - struct allowedips_node *node, *stack[128] = { + struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = { container_of(rcu, struct allowedips_node, rcu) }; unsigned int len = 1; @@ -65,7 +68,7 @@ static void root_free_rcu(struct rcu_head *rcu) static void root_remove_peer_lists(struct allowedips_node *root) { - struct allowedips_node *node, *stack[128] = { root }; + struct allowedips_node *node, *stack[MAX_ALLOWEDIPS_DEPTH] = { root }; unsigned int len = 1; while (len > 0 && (node = stack[--len])) { diff --git a/src/compat/Kbuild.include b/src/compat/Kbuild.include index cc5643e..4c12d29 100644 --- a/src/compat/Kbuild.include +++ b/src/compat/Kbuild.include @@ -45,7 +45,7 @@ ccflags-y += -I$(kbuild-dir)/compat/udp_tunnel/include amneziawg-y += compat/udp_tunnel/udp_tunnel.o endif -ifeq ($(shell grep -s -F "int crypto_memneq" "$(srctree)/include/crypto/algapi.h"),) +ifeq ($(shell grep -s -F "int crypto_memneq" "$(srctree)/include/crypto/algapi.h")$(shell grep -s -F "int crypto_memneq" "$(srctree)/include/crypto/utils.h"),) ccflags-y += -include $(kbuild-dir)/compat/memneq/include.h amneziawg-y += compat/memneq/memneq.o endif @@ -107,3 +107,31 @@ endif ifneq ($(shell grep -s -F "\#define LINUX_PACKAGE_ID \" Debian " "$(CURDIR)/include/generated/package.h"),) ccflags-y += -DISDEBIAN endif + +ifeq ($(wildcard $(srctree)/include/crypto/blake2s.h),) +ccflags-y += -I$(kbuild-dir)/compat/crypto/blake2s/include +endif + +ifeq ($(wildcard $(srctree)/include/crypto/chacha20poly1305.h),) +ccflags-y += -I$(kbuild-dir)/compat/crypto/chacha20poly1305/include +endif + +ifeq ($(wildcard $(srctree)/include/crypto/utils.h),) +ccflags-y += -I$(kbuild-dir)/compat/crypto/utils/include +endif + +ifeq ($(wildcard $(srctree)/include/crypto/curve25519.h),) +ccflags-y += -I$(kbuild-dir)/compat/crypto/curve25519/include +endif + +ifeq ($(wildcard $(srctree)/include/linux/kstrtox.h),) +ccflags-y += -I$(kbuild-dir)/compat/kstrtox/include +endif + +ifeq ($(wildcard $(srctree)/include/net/gso.h),) +ccflags-y += -I$(kbuild-dir)/compat/gso/include +endif + +ifeq ($(wildcard $(srctree)/include/linux/sprintf.h),) +ccflags-y += -I$(kbuild-dir)/compat/sprintf/include +endif diff --git a/src/compat/compat.h b/src/compat/compat.h index 6eda14a..1a6913a 100644 --- a/src/compat/compat.h +++ b/src/compat/compat.h @@ -888,13 +888,14 @@ static inline void skb_mark_not_on_list(struct sk_buff *skb) #endif #endif -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 200) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 19, 249)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 15, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 14, 285)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 10, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 9, 320)) +#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 200) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 19, 249)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 15, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 14, 285)) || (LINUX_VERSION_CODE < KERNEL_VERSION(4, 10, 0) && LINUX_VERSION_CODE >= KERNEL_VERSION(4, 9, 320))) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 10, 0) +#define COMPAT_INIT_CRYPTO #define blake2s_init zinc_blake2s_init #define blake2s_init_key zinc_blake2s_init_key #define blake2s_update zinc_blake2s_update #define blake2s_final zinc_blake2s_final #endif -#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 5, 0) +#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 5, 0) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 10, 0) #define blake2s_hmac zinc_blake2s_hmac #define chacha20 zinc_chacha20 #define hchacha20 zinc_hchacha20 @@ -939,6 +940,13 @@ static inline void skb_mark_not_on_list(struct sk_buff *skb) #define chacha20_neon zinc_chacha20_neon #endif +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 10, 0) +#define COMPAT_CRYPTO_IS_ZINC +#define COMPAT_MAYBE_SIMD_CONTEXT(ctx) , ctx +#else +#define COMPAT_MAYBE_SIMD_CONTEXT(ctx) +#endif + #if LINUX_VERSION_CODE < KERNEL_VERSION(3, 19, 0) && !defined(ISRHEL7) #include static inline int skb_ensure_writable(struct sk_buff *skb, int write_len) @@ -1117,7 +1125,7 @@ static const struct header_ops ip_tunnel_header_ops = { .parse_protocol = ip_tun #endif #endif -#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 16, 0) +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 15, 30) #include struct dst_cache_pcpu { unsigned long refresh_ts; @@ -1192,4 +1200,77 @@ static inline void dst_cache_reset_now(struct dst_cache *dst_cache) #define from_timer(var, callback_timer, timer_fieldname) container_of((struct timer_list *)callback_timer, typeof(*var), timer_fieldname) #endif +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 11, 0) +#include +#define flowi4_to_flowi_common(fl4) flowi4_to_flowi(fl4) +#define flowi6_to_flowi_common(fl4) flowi6_to_flowi(fl4) +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 6, 0) +#define genl_info_dump(cb) genl_dumpit_info(cb) +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 1, 84) && \ + !(LINUX_VERSION_CODE >= KERNEL_VERSION(4, 19, 312) && LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0)) && \ + !(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 274) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 5, 0)) && \ + !(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 10, 215) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 11, 0)) && \ + !(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 15, 154) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 16, 0)) +#define timer_delete_sync(timer) del_timer_sync(timer) +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 2, 0) +#include +static inline u32 get_random_u32_below(u32 ceil) +{ + return get_random_u32() % ceil; +} +static inline u32 get_random_u32_inclusive(u32 floor, u32 ceil) +{ + return floor + get_random_u32_below(ceil - floor + 1); +} +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 1, 0) +#define COMPAT_NETIF_HAS_WEIGHT +#endif + +#if LINUX_VERSION_CODE >= KERNEL_VERSION(6, 1, 0) +#define COMPAT_GENL_HAS_RESV_START_OP +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 2, 0) && \ + !(LINUX_VERSION_CODE >= KERNEL_VERSION(4, 19, 296) && LINUX_VERSION_CODE < KERNEL_VERSION(4, 20, 0)) && \ + !(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 4, 229) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 5, 0)) && \ + !(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 10, 163) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 11, 0)) && \ + !(LINUX_VERSION_CODE >= KERNEL_VERSION(5, 15, 86) && LINUX_VERSION_CODE < KERNEL_VERSION(5, 16, 0)) +#undef DEV_STATS_INC +#define DEV_STATS_INC(DEV, FIELD) ++DEV->stats.FIELD +#undef DEV_STATS_ADD +#define DEV_STATS_ADD(DEV, FIELD, VAL) DEV->stats.FIELD += VAL +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 17, 0) +#define COMPAT_SKB_HAS_SKB_START +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 11, 0) +#define dev_get_tstats64 ip_tunnel_get_stats64 +#endif + +#if LINUX_VERSION_CODE >= KERNEL_VERSION(6, 12, 0) +#define COMPAT_NETDEV_HAS_LLTX_PARAM +#endif + +#if LINUX_VERSION_CODE < KERNEL_VERSION(5, 6, 0) +static inline void dev_sw_netstats_rx_add(struct net_device *dev, unsigned int len) { + struct pcpu_sw_netstats *tstats = get_cpu_ptr(dev->tstats); + + u64_stats_update_begin(&tstats->syncp); + ++tstats->rx_packets; + tstats->rx_bytes += len; + u64_stats_update_end(&tstats->syncp); + put_cpu_ptr(tstats); +} +#endif + #endif /* _WG_COMPAT_H */ diff --git a/src/compat/crypto/blake2s/include/crypto/blake2s.h b/src/compat/crypto/blake2s/include/crypto/blake2s.h new file mode 100644 index 0000000..1c19b09 --- /dev/null +++ b/src/compat/crypto/blake2s/include/crypto/blake2s.h @@ -0,0 +1 @@ +#include diff --git a/src/compat/crypto/chacha20poly1305/include/crypto/chacha20poly1305.h b/src/compat/crypto/chacha20poly1305/include/crypto/chacha20poly1305.h new file mode 100644 index 0000000..db4cff2 --- /dev/null +++ b/src/compat/crypto/chacha20poly1305/include/crypto/chacha20poly1305.h @@ -0,0 +1 @@ +#include diff --git a/src/compat/crypto/curve25519/include/crypto/curve25519.h b/src/compat/crypto/curve25519/include/crypto/curve25519.h new file mode 100644 index 0000000..e74c054 --- /dev/null +++ b/src/compat/crypto/curve25519/include/crypto/curve25519.h @@ -0,0 +1 @@ +#include diff --git a/src/compat/crypto/utils/include/crypto/utils.h b/src/compat/crypto/utils/include/crypto/utils.h new file mode 100644 index 0000000..c5c6826 --- /dev/null +++ b/src/compat/crypto/utils/include/crypto/utils.h @@ -0,0 +1 @@ +#include diff --git a/src/compat/gso/include/net/gso.h b/src/compat/gso/include/net/gso.h new file mode 100644 index 0000000..36e2e07 --- /dev/null +++ b/src/compat/gso/include/net/gso.h @@ -0,0 +1,6 @@ +#ifndef _AWG_COMPAT_NET_GSO +#define _AWG_COMPAT_NET_GSO + +#include + +#endif \ No newline at end of file diff --git a/src/compat/kstrtox/include/linux/kstrtox.h b/src/compat/kstrtox/include/linux/kstrtox.h new file mode 100644 index 0000000..ae7943c --- /dev/null +++ b/src/compat/kstrtox/include/linux/kstrtox.h @@ -0,0 +1 @@ +#include diff --git a/src/compat/sprintf/include/linux/sprintf.h b/src/compat/sprintf/include/linux/sprintf.h new file mode 100644 index 0000000..b5dd3e8 --- /dev/null +++ b/src/compat/sprintf/include/linux/sprintf.h @@ -0,0 +1 @@ +#include \ No newline at end of file diff --git a/src/cookie.c b/src/cookie.c index 3120094..3816cf8 100644 --- a/src/cookie.c +++ b/src/cookie.c @@ -10,11 +10,11 @@ #include "ratelimiter.h" #include "timers.h" -#include -#include +#include +#include +#include #include -#include void wg_cookie_checker_init(struct cookie_checker *checker, struct wg_device *wg) diff --git a/src/device.c b/src/device.c index 0a32dad..97e39ca 100644 --- a/src/device.c +++ b/src/device.c @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -201,7 +202,7 @@ static netdev_tx_t wg_xmit(struct sk_buff *skb, struct net_device *dev) */ while (skb_queue_len(&peer->staged_packet_queue) > MAX_STAGED_PACKETS) { dev_kfree_skb(__skb_dequeue(&peer->staged_packet_queue)); - ++dev->stats.tx_dropped; + DEV_STATS_INC(dev, tx_dropped); } skb_queue_splice_tail(&packets, &peer->staged_packet_queue); spin_unlock_bh(&peer->staged_packet_queue.lock); @@ -219,7 +220,7 @@ err_icmp: else if (skb->protocol == htons(ETH_P_IPV6)) icmpv6_ndo_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, 0); err: - ++dev->stats.tx_errors; + DEV_STATS_INC(dev, tx_errors); kfree_skb(skb); return ret; } @@ -228,7 +229,7 @@ static const struct net_device_ops netdev_ops = { .ndo_open = wg_open, .ndo_stop = wg_stop, .ndo_start_xmit = wg_xmit, - .ndo_get_stats64 = ip_tunnel_get_stats64 + .ndo_get_stats64 = dev_get_tstats64 }; static void wg_destruct(struct net_device *dev) @@ -286,7 +287,11 @@ static void wg_setup(struct net_device *dev) #else dev->tx_queue_len = 0; #endif +#ifdef COMPAT_NETDEV_HAS_LLTX_PARAM + dev->lltx = true; +#else dev->features |= NETIF_F_LLTX; +#endif dev->features |= WG_NETDEV_FEATURES; dev->hw_features |= WG_NETDEV_FEATURES; dev->hw_enc_features |= WG_NETDEV_FEATURES; diff --git a/src/main.c b/src/main.c index c9c7057..6178b5c 100644 --- a/src/main.c +++ b/src/main.c @@ -11,21 +11,22 @@ #include "ratelimiter.h" #include "netlink.h" #include "uapi/wireguard.h" -#include "crypto/zinc.h" #include #include -#include +#include #include static int __init wg_mod_init(void) { int ret; +#ifdef COMPAT_INIT_CRYPTO if ((ret = chacha20_mod_init()) || (ret = poly1305_mod_init()) || (ret = chacha20poly1305_mod_init()) || (ret = blake2s_mod_init()) || (ret = curve25519_mod_init())) return ret; +#endif ret = wg_allowedips_slab_init(); if (ret < 0) diff --git a/src/messages.h b/src/messages.h index aa9f845..65b9ad1 100644 --- a/src/messages.h +++ b/src/messages.h @@ -6,9 +6,9 @@ #ifndef _WG_MESSAGES_H #define _WG_MESSAGES_H -#include -#include -#include +#include +#include +#include #include #include diff --git a/src/netlink.c b/src/netlink.c index 29049ba..8e8789f 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -13,7 +13,7 @@ #include #include #include -#include +#include #include #include @@ -356,8 +356,8 @@ get_peer(struct wg_peer *peer, struct sk_buff *skb, struct dump_ctx *ctx) if (!allowedips_node) goto no_allowedips; if (!ctx->allowedips_seq) - ctx->allowedips_seq = peer->device->peer_allowedips.seq; - else if (ctx->allowedips_seq != peer->device->peer_allowedips.seq) + ctx->allowedips_seq = ctx->wg->peer_allowedips.seq; + else if (ctx->allowedips_seq != ctx->wg->peer_allowedips.seq) goto no_allowedips; allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS); @@ -392,7 +392,7 @@ static int wg_get_device_start(struct netlink_callback *cb) { struct wg_device *wg; - wg = lookup_interface(genl_dumpit_info(cb)->attrs, cb->skb); + wg = lookup_interface(genl_info_dump(cb)->attrs, cb->skb); if (IS_ERR(wg)) return PTR_ERR(wg); DUMP_CTX(cb)->wg = wg; @@ -892,7 +892,6 @@ struct genl_ops genl_ops[] = { #ifdef COMPAT_CANNOT_INDIVIDUAL_NETLINK_OPS_POLICY .policy = device_policy, #endif - // Dummy comment to reduce fuzziness of patch file .flags = GENL_UNS_ADMIN_PERM } }; @@ -910,6 +909,9 @@ __ro_after_init = { .n_ops = ARRAY_SIZE(genl_ops), #else = { +#endif +#ifdef COMPAT_GENL_HAS_RESV_START_OP + .resv_start_op = WG_CMD_SET_DEVICE + 1, #endif .name = WG_GENL_NAME, .version = WG_GENL_VERSION, diff --git a/src/noise.c b/src/noise.c index e27d49b..7de85de 100644 --- a/src/noise.c +++ b/src/noise.c @@ -17,7 +17,7 @@ #include #include #include -#include +#include /* This implements Noise_IKpsk2: * @@ -304,6 +304,41 @@ void wg_noise_set_static_identity_private_key( static_identity->static_public, private_key); } +static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen) +{ + struct blake2s_state state; + u8 x_key[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 }; + u8 i_hash[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32)); + int i; + + if (keylen > BLAKE2S_BLOCK_SIZE) { + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, key, keylen); + blake2s_final(&state, x_key); + } else + memcpy(x_key, key, keylen); + + for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i) + x_key[i] ^= 0x36; + + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE); + blake2s_update(&state, in, inlen); + blake2s_final(&state, i_hash); + + for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i) + x_key[i] ^= 0x5c ^ 0x36; + + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE); + blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE); + blake2s_final(&state, i_hash); + + memcpy(out, i_hash, BLAKE2S_HASH_SIZE); + memzero_explicit(x_key, BLAKE2S_BLOCK_SIZE); + memzero_explicit(i_hash, BLAKE2S_HASH_SIZE); +} + /* This is Hugo Krawczyk's HKDF: * - https://eprint.iacr.org/2010/264.pdf * - https://tools.ietf.org/html/rfc5869 @@ -324,16 +359,14 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, ((third_len || third_dst) && (!second_len || !second_dst)))); /* Extract entropy from data into secret */ - blake2s_hmac(secret, data, chaining_key, BLAKE2S_HASH_SIZE, data_len, - NOISE_HASH_LEN); + hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN); if (!first_dst || !first_len) goto out; /* Expand first key: key = secret, data = 0x1 */ output[0] = 1; - blake2s_hmac(output, output, secret, BLAKE2S_HASH_SIZE, 1, - BLAKE2S_HASH_SIZE); + hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE); memcpy(first_dst, output, first_len); if (!second_dst || !second_len) @@ -341,8 +374,7 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, /* Expand second key: key = secret, data = first-key || 0x2 */ output[BLAKE2S_HASH_SIZE] = 2; - blake2s_hmac(output, output, secret, BLAKE2S_HASH_SIZE, - BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE); + hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE); memcpy(second_dst, output, second_len); if (!third_dst || !third_len) @@ -350,8 +382,7 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, /* Expand third key: key = secret, data = second-key || 0x3 */ output[BLAKE2S_HASH_SIZE] = 3; - blake2s_hmac(output, output, secret, BLAKE2S_HASH_SIZE, - BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE); + hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE); memcpy(third_dst, output, third_len); out: diff --git a/src/peer.c b/src/peer.c index 557dc85..b3370e3 100644 --- a/src/peer.c +++ b/src/peer.c @@ -55,8 +55,11 @@ struct wg_peer *wg_peer_create(struct wg_device *wg, skb_queue_head_init(&peer->staged_packet_queue); wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake); set_bit(NAPI_STATE_NO_BUSY_POLL, &peer->napi.state); - netif_napi_add(wg->dev, &peer->napi, wg_packet_rx_poll, - NAPI_POLL_WEIGHT); +#ifdef COMPAT_NETIF_HAS_WEIGHT + netif_napi_add(wg->dev, &peer->napi, wg_packet_rx_poll, NAPI_POLL_WEIGHT); +#else + netif_napi_add(wg->dev, &peer->napi, wg_packet_rx_poll); +#endif napi_enable(&peer->napi); list_add_tail(&peer->peer_list, &wg->peer_list); INIT_LIST_HEAD(&peer->allowedips_list); diff --git a/src/queueing.c b/src/queueing.c index 8084e74..26d235d 100644 --- a/src/queueing.c +++ b/src/queueing.c @@ -28,6 +28,7 @@ int wg_packet_queue_init(struct crypt_queue *queue, work_func_t function, int ret; memset(queue, 0, sizeof(*queue)); + queue->last_cpu = -1; ret = ptr_ring_init(&queue->ring, len, GFP_KERNEL); if (ret) return ret; diff --git a/src/queueing.h b/src/queueing.h index 03850c4..6ae2a5f 100644 --- a/src/queueing.h +++ b/src/queueing.h @@ -75,20 +75,21 @@ static inline bool wg_check_packet_protocol(struct sk_buff *skb) static inline void wg_reset_packet(struct sk_buff *skb, bool encapsulating) { - const int pfmemalloc = skb->pfmemalloc; - u32 hash = skb->hash; u8 l4_hash = skb->l4_hash; u8 sw_hash = skb->sw_hash; - + u32 hash = skb->hash; skb_scrub_packet(skb, true); +#ifdef COMPAT_SKB_HAS_SKB_START memset(&skb->headers_start, 0, offsetof(struct sk_buff, headers_end) - offsetof(struct sk_buff, headers_start)); - skb->pfmemalloc = pfmemalloc; +#else + memset(&skb->headers, 0, sizeof(skb->headers)); +#endif if (encapsulating) { - skb->hash = hash; skb->l4_hash = l4_hash; skb->sw_hash = sw_hash; + skb->hash = hash; } skb->queue_mapping = 0; skb->nohdr = 0; @@ -111,7 +112,7 @@ static inline int wg_cpumask_choose_online(int *stored_cpu, unsigned int id) { unsigned int cpu = *stored_cpu, cpu_index, i; - if (unlikely(cpu == nr_cpumask_bits || + if (unlikely(cpu >= nr_cpu_ids || !cpumask_test_cpu(cpu, cpu_online_mask))) { cpu_index = id % cpumask_weight(cpu_online_mask); cpu = cpumask_first(cpu_online_mask); @@ -122,20 +123,17 @@ static inline int wg_cpumask_choose_online(int *stored_cpu, unsigned int id) return cpu; } -/* This function is racy, in the sense that next is unlocked, so it could return - * the same CPU twice. A race-free version of this would be to instead store an - * atomic sequence number, do an increment-and-return, and then iterate through - * every possible CPU until we get to that index -- choose_cpu. However that's - * a bit slower, and it doesn't seem like this potential race actually - * introduces any performance loss, so we live with it. +/* This function is racy, in the sense that it's called while last_cpu is + * unlocked, so it could return the same CPU twice. Adding locking or using + * atomic sequence numbers is slower though, and the consequences of racing are + * harmless, so live with it. */ -static inline int wg_cpumask_next_online(int *next) +static inline int wg_cpumask_next_online(int *last_cpu) { - int cpu = *next; - - while (unlikely(!cpumask_test_cpu(cpu, cpu_online_mask))) - cpu = cpumask_next(cpu, cpu_online_mask) % nr_cpumask_bits; - *next = cpumask_next(cpu, cpu_online_mask) % nr_cpumask_bits; + int cpu = cpumask_next(READ_ONCE(*last_cpu), cpu_online_mask); + if (cpu >= nr_cpu_ids) + cpu = cpumask_first(cpu_online_mask); + WRITE_ONCE(*last_cpu, cpu); return cpu; } @@ -164,7 +162,7 @@ static inline void wg_prev_queue_drop_peeked(struct prev_queue *queue) static inline int wg_queue_enqueue_per_device_and_peer( struct crypt_queue *device_queue, struct prev_queue *peer_queue, - struct sk_buff *skb, struct workqueue_struct *wq, int *next_cpu) + struct sk_buff *skb, struct workqueue_struct *wq) { int cpu; @@ -178,7 +176,7 @@ static inline int wg_queue_enqueue_per_device_and_peer( /* Then we queue it up in the device queue, which consumes the * packet as soon as it can. */ - cpu = wg_cpumask_next_online(next_cpu); + cpu = wg_cpumask_next_online(&device_queue->last_cpu); if (unlikely(ptr_ring_produce_bh(&device_queue->ring, skb))) return -EPIPE; queue_work_on(cpu, wq, &per_cpu_ptr(device_queue->worker, cpu)->work); diff --git a/src/receive.c b/src/receive.c index e42006d..e2b67d4 100644 --- a/src/receive.c +++ b/src/receive.c @@ -20,15 +20,8 @@ /* Must be called with bh disabled. */ static void update_rx_stats(struct wg_peer *peer, size_t len) { - struct pcpu_sw_netstats *tstats = - get_cpu_ptr(peer->device->dev->tstats); - - u64_stats_update_begin(&tstats->syncp); - ++tstats->rx_packets; - tstats->rx_bytes += len; + dev_sw_netstats_rx_add(peer->device->dev, len); peer->rx_bytes += len; - u64_stats_update_end(&tstats->syncp); - put_cpu_ptr(tstats); } static size_t validate_header_len(struct sk_buff *skb, struct wg_device *wg) @@ -293,7 +286,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair, if (unlikely(!READ_ONCE(keypair->receiving.is_valid) || wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) || - keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) { + READ_ONCE(keypair->receiving_counter.counter) >= REJECT_AFTER_MESSAGES)) { WRITE_ONCE(keypair->receiving.is_valid, false); return false; } @@ -305,7 +298,7 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair, * call skb_cow_data, so that there's no chance that data is removed * from the skb, so that later we can extract the original endpoint. */ - offset = skb->data - skb_network_header(skb); + offset = -skb_network_offset(skb); skb_push(skb, offset); num_frags = skb_cow_data(skb, 0, &trailer); offset += sizeof(struct message_data); @@ -318,9 +311,9 @@ static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair, return false; if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0, - PACKET_CB(skb)->nonce, - keypair->receiving.key, - simd_context)) + PACKET_CB(skb)->nonce, + keypair->receiving.key + COMPAT_MAYBE_SIMD_CONTEXT(simd_context))) return false; /* Another ugly situation of pushing and pulling the header so as to @@ -361,7 +354,7 @@ static bool counter_validate(struct noise_replay_counter *counter, u64 their_cou for (i = 1; i <= top; ++i) counter->backtrack[(i + index_current) & ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0; - counter->counter = their_counter; + WRITE_ONCE(counter->counter, their_counter); } index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1; @@ -461,20 +454,20 @@ dishonest_packet_peer: net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n", dev->name, skb, peer->internal_id, &peer->endpoint.addr); - ++dev->stats.rx_errors; - ++dev->stats.rx_frame_errors; + DEV_STATS_INC(dev, rx_errors); + DEV_STATS_INC(dev, rx_frame_errors); goto packet_processed; dishonest_packet_type: net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); - ++dev->stats.rx_errors; - ++dev->stats.rx_frame_errors; + DEV_STATS_INC(dev, rx_errors); + DEV_STATS_INC(dev, rx_frame_errors); goto packet_processed; dishonest_packet_size: net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); - ++dev->stats.rx_errors; - ++dev->stats.rx_length_errors; + DEV_STATS_INC(dev, rx_errors); + DEV_STATS_INC(dev, rx_length_errors); goto packet_processed; packet_processed: dev_kfree_skb(skb); @@ -508,7 +501,7 @@ int wg_packet_rx_poll(struct napi_struct *napi, int budget) net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", peer->device->dev->name, PACKET_CB(skb)->nonce, - keypair->receiving_counter.counter); + READ_ONCE(keypair->receiving_counter.counter)); goto next; } @@ -573,7 +566,7 @@ static void wg_packet_consume_data(struct wg_device *wg, struct sk_buff *skb) goto err; ret = wg_queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, - wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu); + wg->packet_crypt_wq); if (unlikely(ret == -EPIPE)) wg_queue_enqueue_per_peer_rx(skb, PACKET_STATE_DEAD); if (likely(!ret || ret == -EPIPE)) { diff --git a/src/send.c b/src/send.c index d6c27f1..50b17b8 100644 --- a/src/send.c +++ b/src/send.c @@ -3,6 +3,7 @@ * Copyright (C) 2015-2019 Jason A. Donenfeld . All Rights Reserved. */ +#include "compat/compat.h" #include "queueing.h" #include "timers.h" #include "device.h" @@ -219,8 +220,8 @@ static unsigned int calculate_skb_padding(struct sk_buff *skb) return padded_size - last_unit; } -static bool encrypt_packet(u32 message_type, struct sk_buff *skb, struct noise_keypair *keypair, - simd_context_t *simd_context) +static bool encrypt_packet(u32 message_type, struct sk_buff *skb, struct noise_keypair *keypair + COMPAT_MAYBE_SIMD_CONTEXT(simd_context_t *simd_context)) { unsigned int padding_len, plaintext_len, trailer_len; struct scatterlist sg[MAX_SKB_FRAGS + 8]; @@ -276,15 +277,15 @@ static bool encrypt_packet(u32 message_type, struct sk_buff *skb, struct noise_k return false; return chacha20poly1305_encrypt_sg_inplace(sg, plaintext_len, NULL, 0, PACKET_CB(skb)->nonce, - keypair->sending.key, - simd_context); + keypair->sending.key + COMPAT_MAYBE_SIMD_CONTEXT(simd_context)); } void wg_packet_send_keepalive(struct wg_peer *peer) { struct sk_buff *skb; - if (skb_queue_empty(&peer->staged_packet_queue)) { + if (skb_queue_empty_lockless(&peer->staged_packet_queue)) { skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH, GFP_ATOMIC); if (unlikely(!skb)) @@ -352,9 +353,11 @@ void wg_packet_encrypt_worker(struct work_struct *work) work)->ptr; struct sk_buff *first, *skb, *next; struct wg_device *wg; - simd_context_t simd_context; +#ifdef COMPAT_CRYPTO_IS_ZINC + simd_context_t simd_context; simd_get(&simd_context); +#endif while ((first = ptr_ring_consume_bh(&queue->ring)) != NULL) { enum packet_state state = PACKET_STATE_CRYPTED; @@ -364,8 +367,8 @@ void wg_packet_encrypt_worker(struct work_struct *work) if (likely(encrypt_packet(PACKET_PEER(first)->advanced_security ? wg->advanced_security_config.transport_packet_magic_header : MESSAGE_DATA, skb, - PACKET_CB(first)->keypair, - &simd_context))) { + PACKET_CB(first)->keypair + COMPAT_MAYBE_SIMD_CONTEXT(&simd_context)))) { wg_reset_packet(skb, true); } else { state = PACKET_STATE_DEAD; @@ -374,9 +377,15 @@ void wg_packet_encrypt_worker(struct work_struct *work) } wg_queue_enqueue_per_peer_tx(first, state); +#ifdef COMPAT_CRYPTO_IS_ZINC simd_relax(&simd_context); +#endif + if (need_resched()) + cond_resched(); } +#ifdef COMPAT_CRYPTO_IS_ZINC simd_put(&simd_context); +#endif } static void wg_packet_create_data(struct wg_peer *peer, struct sk_buff *first) @@ -389,7 +398,7 @@ static void wg_packet_create_data(struct wg_peer *peer, struct sk_buff *first) goto err; ret = wg_queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, - wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu); + wg->packet_crypt_wq); if (unlikely(ret == -EPIPE)) wg_queue_enqueue_per_peer_tx(first, PACKET_STATE_DEAD); err: @@ -404,7 +413,8 @@ err: void wg_packet_purge_staged_packets(struct wg_peer *peer) { spin_lock_bh(&peer->staged_packet_queue.lock); - peer->device->dev->stats.tx_dropped += peer->staged_packet_queue.qlen; + DEV_STATS_ADD(peer->device->dev, tx_dropped, + peer->staged_packet_queue.qlen); __skb_queue_purge(&peer->staged_packet_queue); spin_unlock_bh(&peer->staged_packet_queue.lock); } diff --git a/src/socket.c b/src/socket.c index 2dd574f..23fa6d1 100644 --- a/src/socket.c +++ b/src/socket.c @@ -49,7 +49,7 @@ static int send4(struct wg_device *wg, struct sk_buff *skb, rt = dst_cache_get_ip4(cache, &fl.saddr); if (!rt) { - security_sk_classify_flow(sock, flowi4_to_flowi(&fl)); + security_sk_classify_flow(sock, flowi4_to_flowi_common(&fl)); if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0, fl.saddr, RT_SCOPE_HOST))) { endpoint->src4.s_addr = 0; @@ -129,7 +129,7 @@ static int send6(struct wg_device *wg, struct sk_buff *skb, dst = dst_cache_get_ip6(cache, &fl.saddr); if (!dst) { - security_sk_classify_flow(sock, flowi6_to_flowi(&fl)); + security_sk_classify_flow(sock, flowi6_to_flowi_common(&fl)); if (unlikely(!ipv6_addr_any(&fl.saddr) && !ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) { endpoint->src6 = fl.saddr = in6addr_any; diff --git a/src/timers.c b/src/timers.c index d54d32a..968bdb4 100644 --- a/src/timers.c +++ b/src/timers.c @@ -46,7 +46,7 @@ static void wg_expired_retransmit_handshake(struct timer_list *timer) if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) { pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n", peer->device->dev->name, peer->internal_id, - &peer->endpoint.addr, MAX_TIMER_HANDSHAKES + 2); + &peer->endpoint.addr, (int)MAX_TIMER_HANDSHAKES + 2); del_timer(&peer->timer_send_keepalive); /* We drop all packets without a keypair and don't try again, @@ -64,7 +64,7 @@ static void wg_expired_retransmit_handshake(struct timer_list *timer) ++peer->timer_handshake_attempts; pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d seconds, retrying (try %d)\n", peer->device->dev->name, peer->internal_id, - &peer->endpoint.addr, REKEY_TIMEOUT, + &peer->endpoint.addr, (int)REKEY_TIMEOUT, peer->timer_handshake_attempts + 1); /* We clear the endpoint address src address, in case this is @@ -94,7 +94,7 @@ static void wg_expired_new_handshake(struct timer_list *timer) pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n", peer->device->dev->name, peer->internal_id, - &peer->endpoint.addr, KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); + &peer->endpoint.addr, (int)(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT)); /* We clear the endpoint address src address, in case this is the cause * of trouble. */ @@ -126,7 +126,7 @@ static void wg_queued_expired_zero_key_material(struct work_struct *work) pr_debug("%s: Zeroing out all keys for peer %llu (%pISpfsc), since we haven't received a new one in %d seconds\n", peer->device->dev->name, peer->internal_id, - &peer->endpoint.addr, REJECT_AFTER_TIME * 3); + &peer->endpoint.addr, (int)REJECT_AFTER_TIME * 3); wg_noise_handshake_clear(&peer->handshake); wg_noise_keypairs_clear(&peer->keypairs); wg_peer_put(peer); @@ -147,7 +147,7 @@ void wg_timers_data_sent(struct wg_peer *peer) if (!timer_pending(&peer->timer_new_handshake)) mod_peer_timer(peer, &peer->timer_new_handshake, jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ + - prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES)); + get_random_u32_below(REKEY_TIMEOUT_JITTER_MAX_JIFFIES)); } /* Should be called after an authenticated data packet is received. */ @@ -183,7 +183,7 @@ void wg_timers_handshake_initiated(struct wg_peer *peer) { mod_peer_timer(peer, &peer->timer_retransmit_handshake, jiffies + REKEY_TIMEOUT * HZ + - prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES)); + get_random_u32_below(REKEY_TIMEOUT_JITTER_MAX_JIFFIES)); } /* Should be called after a handshake response message is received and processed @@ -234,10 +234,10 @@ void wg_timers_init(struct wg_peer *peer) void wg_timers_stop(struct wg_peer *peer) { - del_timer_sync(&peer->timer_retransmit_handshake); - del_timer_sync(&peer->timer_send_keepalive); - del_timer_sync(&peer->timer_new_handshake); - del_timer_sync(&peer->timer_zero_key_material); - del_timer_sync(&peer->timer_persistent_keepalive); + timer_delete_sync(&peer->timer_retransmit_handshake); + timer_delete_sync(&peer->timer_send_keepalive); + timer_delete_sync(&peer->timer_new_handshake); + timer_delete_sync(&peer->timer_zero_key_material); + timer_delete_sync(&peer->timer_persistent_keepalive); flush_work(&peer->clear_peer_work); }