From 85654e3dc02196f331c85ae2b6033bc1bf5083ac Mon Sep 17 00:00:00 2001 From: Odd Stranne Date: Wed, 12 May 2021 19:28:37 +0200 Subject: [PATCH] Use spinlock to protect IP addresses --- src/firewall/callouts.cpp | 4 ++-- src/firewall/context.h | 2 +- src/firewall/firewall.cpp | 12 +++++++----- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/firewall/callouts.cpp b/src/firewall/callouts.cpp index bd33381..edfc6c1 100644 --- a/src/firewall/callouts.cpp +++ b/src/firewall/callouts.cpp @@ -254,7 +254,7 @@ RewriteBind const bool ipv4 = FixedValues->layerId == FWPS_LAYER_ALE_BIND_REDIRECT_V4; - WdfWaitLockAcquire(Context->IpAddresses.Lock, NULL); + WdfSpinLockAcquire(Context->IpAddresses.Lock); if (ipv4) { @@ -307,7 +307,7 @@ RewriteBind } } - WdfWaitLockRelease(Context->IpAddresses.Lock); + WdfSpinLockRelease(Context->IpAddresses.Lock); Cleanup_data: diff --git a/src/firewall/context.h b/src/firewall/context.h index 60451aa..6b69ba3 100644 --- a/src/firewall/context.h +++ b/src/firewall/context.h @@ -13,7 +13,7 @@ namespace firewall struct IP_ADDRESSES_MGMT { - WDFWAITLOCK Lock; + WDFSPINLOCK Lock; ST_IP_ADDRESSES Addresses; SPLITTING_MODE SplittingMode; }; diff --git a/src/firewall/firewall.cpp b/src/firewall/firewall.cpp index f7f6af1..cb7b190 100644 --- a/src/firewall/firewall.cpp +++ b/src/firewall/firewall.cpp @@ -987,11 +987,11 @@ Initialize context->ProcessEventBroker = ProcessEventBroker; context->Eventing = Eventing; - auto status = WdfWaitLockCreate(WDF_NO_OBJECT_ATTRIBUTES, &context->IpAddresses.Lock); + auto status = WdfSpinLockCreate(WDF_NO_OBJECT_ATTRIBUTES, &context->IpAddresses.Lock); if (!NT_SUCCESS(status)) { - DbgPrint("WdfWaitLockCreate() failed 0x%X\n", status); + DbgPrint("WdfSpinLockCreate() failed 0x%X\n", status); context->IpAddresses.Lock = NULL; @@ -1429,12 +1429,14 @@ RegisterUpdatedIpAddresses goto Abort; } - WdfWaitLockAcquire(Context->IpAddresses.Lock, NULL); + auto intermediateNonPagedAddresses = *IpAddresses; - Context->IpAddresses.Addresses = *IpAddresses; + WdfSpinLockAcquire(Context->IpAddresses.Lock); + + Context->IpAddresses.Addresses = intermediateNonPagedAddresses; Context->IpAddresses.SplittingMode = newMode; - WdfWaitLockRelease(Context->IpAddresses.Lock); + WdfSpinLockRelease(Context->IpAddresses.Lock); Context->ActiveFilters = newActiveFilters;