fix(conn): preserve batching in ConcealBind.Send

Batch active UDP conceal output before forwarding to the inner bind so
StdNetBind can still use WriteBatch/sendmmsg and UDP GSO coalescing.

Preserve the no-op fast path by forwarding the original input batch
directly when the conceal pipeline is inactive. Clone prelude datagrams
before retaining them, keep encoded datagrams in pooled buffers until
their send completes, and defensively cap flushes by inner BatchSize.

Add regression tests for active conceal batching, exact flattened wire
order with prelude/junk datagrams, and no-op batch forwarding.
This commit is contained in:
Frog Rocky
2026-04-23 08:55:11 +02:00
committed by Yaroslav Gurov
parent 4d8f90b9af
commit 9e06d7e934
2 changed files with 181 additions and 5 deletions

View File

@@ -1,6 +1,7 @@
package conn
import (
"bytes"
"sync"
"github.com/amnezia-vpn/amneziawg-go/conceal"
@@ -126,9 +127,66 @@ func (b *ConcealBind) Send(bufs [][]byte, ep Endpoint) error {
return b.inner.Send(bufs, ep)
}
batchSize := b.inner.BatchSize()
if batchSize < 1 {
batchSize = 1
}
batch := make([][]byte, 0, batchSize)
retained := make([][]byte, 0, batchSize)
putRetained := func() {
for i, buf := range retained {
b.bufPool.Put(buf)
retained[i] = nil
}
retained = retained[:0]
}
clearBatch := func() {
for i := range batch {
batch[i] = nil
}
batch = batch[:0]
}
defer func() {
putRetained()
clearBatch()
}()
flush := func() error {
if len(batch) == 0 {
return nil
}
err := b.inner.Send(batch, ep)
putRetained()
clearBatch()
return err
}
appendPacket := func(packet []byte, retainedBuf []byte) error {
if len(batch) == batchSize {
if err := flush(); err != nil {
if retainedBuf != nil {
b.bufPool.Put(retainedBuf)
}
return err
}
}
batch = append(batch, packet)
if retainedBuf != nil {
retained = append(retained, retainedBuf)
}
if len(batch) == batchSize {
return flush()
}
return nil
}
for _, buf := range bufs {
if err := pipeline.EmitPrelude(buf, func(packet []byte) error {
return b.inner.Send([][]byte{packet}, ep)
return appendPacket(bytes.Clone(packet), nil)
}); err != nil {
return err
}
@@ -139,14 +197,12 @@ func (b *ConcealBind) Send(bufs [][]byte, ep Endpoint) error {
b.bufPool.Put(encoded)
return err
}
err = b.inner.Send([][]byte{encoded[:n]}, ep)
b.bufPool.Put(encoded)
if err != nil {
if err := appendPacket(encoded[:n], encoded); err != nil {
return err
}
}
return nil
return flush()
}
func (b *ConcealBind) ParseEndpoint(s string) (Endpoint, error) {

View File

@@ -2,6 +2,7 @@ package conn
import (
"bytes"
"encoding/binary"
"net"
"net/netip"
"slices"
@@ -57,6 +58,9 @@ func TestConcealBindNoOpWithoutOpts(t *testing.T) {
if len(inner.sendCalls) != 1 {
t.Fatalf("send calls = %d, want 1", len(inner.sendCalls))
}
if got := sendCallBatchLengths(inner.sendCalls); !slices.Equal(got, []int{2}) {
t.Fatalf("send batch lengths = %v, want [2]", got)
}
if got := inner.sendCalls[0].packets; !slices.EqualFunc(got, [][]byte{initiation, transport}, bytes.Equal) {
t.Fatalf("sent packets changed on no-op path")
}
@@ -87,6 +91,105 @@ func TestConcealBindNoOpWithoutOpts(t *testing.T) {
}
}
func TestConcealBindSendBatchesActivePipeline(t *testing.T) {
inner := &fakePacketBind{batchSize: 3}
bind := NewConcealBind(inner)
bind.SetFramedOpts(conceal.FramedOpts{
H1: mustHeader(t, "777"),
H2: mustHeader(t, "778"),
H4: mustHeader(t, "779"),
})
endpoint, err := bind.ParseEndpoint("127.0.0.1:51820")
if err != nil {
t.Fatalf("parse endpoint: %v", err)
}
initiation := makeInitiationPacket()
transport := makeTransportPacket()
response := makeResponsePacket()
if err := bind.Send([][]byte{initiation, transport, response}, endpoint); err != nil {
t.Fatalf("send: %v", err)
}
if got := sendCallBatchLengths(inner.sendCalls); !slices.Equal(got, []int{3}) {
t.Fatalf("send batch lengths = %v, want [3]", got)
}
wirePackets := flattenSendCalls(inner.sendCalls)
if len(wirePackets) != 3 {
t.Fatalf("wire packet count = %d, want 3", len(wirePackets))
}
if got := wireHeader(t, wirePackets[0]); got != 777 {
t.Fatalf("wire packet 0 header = %d, want 777", got)
}
if got := wireHeader(t, wirePackets[1]); got != 779 {
t.Fatalf("wire packet 1 header = %d, want 779", got)
}
if got := wireHeader(t, wirePackets[2]); got != 778 {
t.Fatalf("wire packet 2 header = %d, want 778", got)
}
}
func TestConcealBindSendBatchesPreludeInWireOrder(t *testing.T) {
inner := &fakePacketBind{batchSize: 3}
bind := NewConcealBind(inner)
bind.SetFramedOpts(conceal.FramedOpts{
H1: mustHeader(t, "777"),
H4: mustHeader(t, "779"),
})
bind.SetPreludeOpts(conceal.PreludeOpts{
Jc: 1,
Jmin: 3,
Jmax: 3,
RulesArr: [5]conceal.Rules{
mustParseRules(t, "<b 0xaabb>"),
},
})
endpoint, err := bind.ParseEndpoint("127.0.0.1:51820")
if err != nil {
t.Fatalf("parse endpoint: %v", err)
}
firstInitiation := makeInitiationPacket()
transport := makeTransportPacket()
secondInitiation := makeInitiationPacket()
if err := bind.Send([][]byte{firstInitiation, transport, secondInitiation}, endpoint); err != nil {
t.Fatalf("send: %v", err)
}
if got := sendCallBatchLengths(inner.sendCalls); !slices.Equal(got, []int{3, 3, 1}) {
t.Fatalf("send batch lengths = %v, want [3 3 1]", got)
}
wirePackets := flattenSendCalls(inner.sendCalls)
if len(wirePackets) != 7 {
t.Fatalf("wire packet count = %d, want 7", len(wirePackets))
}
if !bytes.Equal(wirePackets[0], []byte{0xaa, 0xbb}) {
t.Fatalf("wire packet 0 prelude = %x, want aabb", wirePackets[0])
}
if len(wirePackets[1]) != 3 {
t.Fatalf("wire packet 1 junk len = %d, want 3", len(wirePackets[1]))
}
if got := wireHeader(t, wirePackets[2]); got != 777 {
t.Fatalf("wire packet 2 header = %d, want 777", got)
}
if got := wireHeader(t, wirePackets[3]); got != 779 {
t.Fatalf("wire packet 3 header = %d, want 779", got)
}
if !bytes.Equal(wirePackets[4], []byte{0xaa, 0xbb}) {
t.Fatalf("wire packet 4 prelude = %x, want aabb", wirePackets[4])
}
if len(wirePackets[5]) != 3 {
t.Fatalf("wire packet 5 junk len = %d, want 3", len(wirePackets[5]))
}
if got := wireHeader(t, wirePackets[6]); got != 777 {
t.Fatalf("wire packet 6 header = %d, want 777", got)
}
}
func TestConcealBindSendAndReceive(t *testing.T) {
senderInner := &fakePacketBind{batchSize: 4}
sender := newTestConcealBind(t, senderInner)
@@ -387,5 +490,22 @@ func flattenSendCalls(calls []fakeSendCall) [][]byte {
return out
}
func sendCallBatchLengths(calls []fakeSendCall) []int {
lengths := make([]int, 0, len(calls))
for _, call := range calls {
lengths = append(lengths, len(call.packets))
}
return lengths
}
func wireHeader(t *testing.T, packet []byte) uint32 {
t.Helper()
if len(packet) < 4 {
t.Fatalf("wire packet len = %d, want at least 4", len(packet))
}
return binary.LittleEndian.Uint32(packet[:4])
}
var _ Bind = (*fakePacketBind)(nil)
var _ Endpoint = (*fakePacketEndpoint)(nil)