mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2026-05-17 08:15:49 +03:00
fix(conn): preserve batching in ConcealBind.Send
Batch active UDP conceal output before forwarding to the inner bind so StdNetBind can still use WriteBatch/sendmmsg and UDP GSO coalescing. Preserve the no-op fast path by forwarding the original input batch directly when the conceal pipeline is inactive. Clone prelude datagrams before retaining them, keep encoded datagrams in pooled buffers until their send completes, and defensively cap flushes by inner BatchSize. Add regression tests for active conceal batching, exact flattened wire order with prelude/junk datagrams, and no-op batch forwarding.
This commit is contained in:
committed by
Yaroslav Gurov
parent
4d8f90b9af
commit
9e06d7e934
@@ -1,6 +1,7 @@
|
||||
package conn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conceal"
|
||||
@@ -126,9 +127,66 @@ func (b *ConcealBind) Send(bufs [][]byte, ep Endpoint) error {
|
||||
return b.inner.Send(bufs, ep)
|
||||
}
|
||||
|
||||
batchSize := b.inner.BatchSize()
|
||||
if batchSize < 1 {
|
||||
batchSize = 1
|
||||
}
|
||||
|
||||
batch := make([][]byte, 0, batchSize)
|
||||
retained := make([][]byte, 0, batchSize)
|
||||
|
||||
putRetained := func() {
|
||||
for i, buf := range retained {
|
||||
b.bufPool.Put(buf)
|
||||
retained[i] = nil
|
||||
}
|
||||
retained = retained[:0]
|
||||
}
|
||||
clearBatch := func() {
|
||||
for i := range batch {
|
||||
batch[i] = nil
|
||||
}
|
||||
batch = batch[:0]
|
||||
}
|
||||
defer func() {
|
||||
putRetained()
|
||||
clearBatch()
|
||||
}()
|
||||
|
||||
flush := func() error {
|
||||
if len(batch) == 0 {
|
||||
return nil
|
||||
}
|
||||
err := b.inner.Send(batch, ep)
|
||||
putRetained()
|
||||
clearBatch()
|
||||
return err
|
||||
}
|
||||
|
||||
appendPacket := func(packet []byte, retainedBuf []byte) error {
|
||||
if len(batch) == batchSize {
|
||||
if err := flush(); err != nil {
|
||||
if retainedBuf != nil {
|
||||
b.bufPool.Put(retainedBuf)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
batch = append(batch, packet)
|
||||
if retainedBuf != nil {
|
||||
retained = append(retained, retainedBuf)
|
||||
}
|
||||
|
||||
if len(batch) == batchSize {
|
||||
return flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, buf := range bufs {
|
||||
if err := pipeline.EmitPrelude(buf, func(packet []byte) error {
|
||||
return b.inner.Send([][]byte{packet}, ep)
|
||||
return appendPacket(bytes.Clone(packet), nil)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -139,14 +197,12 @@ func (b *ConcealBind) Send(bufs [][]byte, ep Endpoint) error {
|
||||
b.bufPool.Put(encoded)
|
||||
return err
|
||||
}
|
||||
err = b.inner.Send([][]byte{encoded[:n]}, ep)
|
||||
b.bufPool.Put(encoded)
|
||||
if err != nil {
|
||||
if err := appendPacket(encoded[:n], encoded); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return flush()
|
||||
}
|
||||
|
||||
func (b *ConcealBind) ParseEndpoint(s string) (Endpoint, error) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package conn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@@ -57,6 +58,9 @@ func TestConcealBindNoOpWithoutOpts(t *testing.T) {
|
||||
if len(inner.sendCalls) != 1 {
|
||||
t.Fatalf("send calls = %d, want 1", len(inner.sendCalls))
|
||||
}
|
||||
if got := sendCallBatchLengths(inner.sendCalls); !slices.Equal(got, []int{2}) {
|
||||
t.Fatalf("send batch lengths = %v, want [2]", got)
|
||||
}
|
||||
if got := inner.sendCalls[0].packets; !slices.EqualFunc(got, [][]byte{initiation, transport}, bytes.Equal) {
|
||||
t.Fatalf("sent packets changed on no-op path")
|
||||
}
|
||||
@@ -87,6 +91,105 @@ func TestConcealBindNoOpWithoutOpts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcealBindSendBatchesActivePipeline(t *testing.T) {
|
||||
inner := &fakePacketBind{batchSize: 3}
|
||||
bind := NewConcealBind(inner)
|
||||
bind.SetFramedOpts(conceal.FramedOpts{
|
||||
H1: mustHeader(t, "777"),
|
||||
H2: mustHeader(t, "778"),
|
||||
H4: mustHeader(t, "779"),
|
||||
})
|
||||
|
||||
endpoint, err := bind.ParseEndpoint("127.0.0.1:51820")
|
||||
if err != nil {
|
||||
t.Fatalf("parse endpoint: %v", err)
|
||||
}
|
||||
|
||||
initiation := makeInitiationPacket()
|
||||
transport := makeTransportPacket()
|
||||
response := makeResponsePacket()
|
||||
if err := bind.Send([][]byte{initiation, transport, response}, endpoint); err != nil {
|
||||
t.Fatalf("send: %v", err)
|
||||
}
|
||||
|
||||
if got := sendCallBatchLengths(inner.sendCalls); !slices.Equal(got, []int{3}) {
|
||||
t.Fatalf("send batch lengths = %v, want [3]", got)
|
||||
}
|
||||
|
||||
wirePackets := flattenSendCalls(inner.sendCalls)
|
||||
if len(wirePackets) != 3 {
|
||||
t.Fatalf("wire packet count = %d, want 3", len(wirePackets))
|
||||
}
|
||||
if got := wireHeader(t, wirePackets[0]); got != 777 {
|
||||
t.Fatalf("wire packet 0 header = %d, want 777", got)
|
||||
}
|
||||
if got := wireHeader(t, wirePackets[1]); got != 779 {
|
||||
t.Fatalf("wire packet 1 header = %d, want 779", got)
|
||||
}
|
||||
if got := wireHeader(t, wirePackets[2]); got != 778 {
|
||||
t.Fatalf("wire packet 2 header = %d, want 778", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcealBindSendBatchesPreludeInWireOrder(t *testing.T) {
|
||||
inner := &fakePacketBind{batchSize: 3}
|
||||
bind := NewConcealBind(inner)
|
||||
bind.SetFramedOpts(conceal.FramedOpts{
|
||||
H1: mustHeader(t, "777"),
|
||||
H4: mustHeader(t, "779"),
|
||||
})
|
||||
bind.SetPreludeOpts(conceal.PreludeOpts{
|
||||
Jc: 1,
|
||||
Jmin: 3,
|
||||
Jmax: 3,
|
||||
RulesArr: [5]conceal.Rules{
|
||||
mustParseRules(t, "<b 0xaabb>"),
|
||||
},
|
||||
})
|
||||
|
||||
endpoint, err := bind.ParseEndpoint("127.0.0.1:51820")
|
||||
if err != nil {
|
||||
t.Fatalf("parse endpoint: %v", err)
|
||||
}
|
||||
|
||||
firstInitiation := makeInitiationPacket()
|
||||
transport := makeTransportPacket()
|
||||
secondInitiation := makeInitiationPacket()
|
||||
if err := bind.Send([][]byte{firstInitiation, transport, secondInitiation}, endpoint); err != nil {
|
||||
t.Fatalf("send: %v", err)
|
||||
}
|
||||
|
||||
if got := sendCallBatchLengths(inner.sendCalls); !slices.Equal(got, []int{3, 3, 1}) {
|
||||
t.Fatalf("send batch lengths = %v, want [3 3 1]", got)
|
||||
}
|
||||
|
||||
wirePackets := flattenSendCalls(inner.sendCalls)
|
||||
if len(wirePackets) != 7 {
|
||||
t.Fatalf("wire packet count = %d, want 7", len(wirePackets))
|
||||
}
|
||||
if !bytes.Equal(wirePackets[0], []byte{0xaa, 0xbb}) {
|
||||
t.Fatalf("wire packet 0 prelude = %x, want aabb", wirePackets[0])
|
||||
}
|
||||
if len(wirePackets[1]) != 3 {
|
||||
t.Fatalf("wire packet 1 junk len = %d, want 3", len(wirePackets[1]))
|
||||
}
|
||||
if got := wireHeader(t, wirePackets[2]); got != 777 {
|
||||
t.Fatalf("wire packet 2 header = %d, want 777", got)
|
||||
}
|
||||
if got := wireHeader(t, wirePackets[3]); got != 779 {
|
||||
t.Fatalf("wire packet 3 header = %d, want 779", got)
|
||||
}
|
||||
if !bytes.Equal(wirePackets[4], []byte{0xaa, 0xbb}) {
|
||||
t.Fatalf("wire packet 4 prelude = %x, want aabb", wirePackets[4])
|
||||
}
|
||||
if len(wirePackets[5]) != 3 {
|
||||
t.Fatalf("wire packet 5 junk len = %d, want 3", len(wirePackets[5]))
|
||||
}
|
||||
if got := wireHeader(t, wirePackets[6]); got != 777 {
|
||||
t.Fatalf("wire packet 6 header = %d, want 777", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcealBindSendAndReceive(t *testing.T) {
|
||||
senderInner := &fakePacketBind{batchSize: 4}
|
||||
sender := newTestConcealBind(t, senderInner)
|
||||
@@ -387,5 +490,22 @@ func flattenSendCalls(calls []fakeSendCall) [][]byte {
|
||||
return out
|
||||
}
|
||||
|
||||
func sendCallBatchLengths(calls []fakeSendCall) []int {
|
||||
lengths := make([]int, 0, len(calls))
|
||||
for _, call := range calls {
|
||||
lengths = append(lengths, len(call.packets))
|
||||
}
|
||||
return lengths
|
||||
}
|
||||
|
||||
func wireHeader(t *testing.T, packet []byte) uint32 {
|
||||
t.Helper()
|
||||
|
||||
if len(packet) < 4 {
|
||||
t.Fatalf("wire packet len = %d, want at least 4", len(packet))
|
||||
}
|
||||
return binary.LittleEndian.Uint32(packet[:4])
|
||||
}
|
||||
|
||||
var _ Bind = (*fakePacketBind)(nil)
|
||||
var _ Endpoint = (*fakePacketEndpoint)(nil)
|
||||
|
||||
Reference in New Issue
Block a user