perf(conceal): optimize record and prelude hot paths

Replace hot-path bytes.NewBuffer usage with slice-backed readers and
writers, and build FlexBuffer-backed read/write contexts by value.

Remove defer from the targeted per-message and per-batch write paths,
return pooled buffers explicitly after downstream writes complete, and
keep TCP/UDP prelude behavior unchanged.

Add focused regression tests for stream, UDP, and batch masquerade and
prelude writes, and fix write-side bookkeeping to emit the actual
encoded bytes.
This commit is contained in:
Frog Rocky
2026-03-27 14:59:19 +01:00
committed by Yaroslav Gurov
parent c65f05a009
commit f4cdedf40a
7 changed files with 436 additions and 52 deletions

View File

@@ -178,15 +178,15 @@ func benchmarkEncodeFramedRecord(opts FramedOpts, payload []byte) []byte {
func benchmarkEncodeMasqueradeRecord(rules Rules, payload []byte) []byte {
pool := benchmarkNewBufferPool()
ctx := &writeContext{
ctx := writeContext{
FlexBuffer: WrapFlexBuffer(payload),
BufferPool: WrapBufferPool(pool),
}
tmp := pool.Get().([]byte)
defer pool.Put(tmp)
w := bytes.NewBuffer(tmp[:0])
if err := rules.Write(w, ctx); err != nil {
w := newSliceWriter(tmp)
if err := rules.Write(&w, &ctx); err != nil {
panic(err)
}
return append([]byte(nil), w.Bytes()...)

View File

@@ -2,8 +2,8 @@ package conceal
import "sync"
func WrapFlexBuffer(buf []byte) *FlexBuffer {
return &FlexBuffer{
func WrapFlexBuffer(buf []byte) FlexBuffer {
return FlexBuffer{
buf: buf,
}
}

View File

@@ -1,7 +1,6 @@
package conceal
import (
"bytes"
"net"
"sync"
@@ -48,12 +47,12 @@ func (c *MasqueradeConn) ReadRecord(b []byte) (n int, err error) {
return 0, ErrNoReadRecord
}
ctx := &readContext{
ctx := readContext{
FlexBuffer: WrapFlexBuffer(b),
BufferPool: c.pool,
}
if err := c.rulesIn.Read(c.Conn, ctx); err != nil {
if err := c.rulesIn.Read(c.Conn, &ctx); err != nil {
return 0, err
}
@@ -65,25 +64,26 @@ func (c *MasqueradeConn) WriteRecord(b []byte) (n int, err error) {
return 0, ErrNoWriteRecord
}
ctx := &writeContext{
ctx := writeContext{
FlexBuffer: WrapFlexBuffer(b),
BufferPool: c.pool,
}
t := c.pool.Get()
defer c.pool.Put(t)
w := newSliceWriter(t)
w := bytes.NewBuffer(t[:0])
if err := c.rulesOut.Write(w, ctx); err != nil {
if err := c.rulesOut.Write(&w, &ctx); err != nil {
c.pool.Put(t)
return 0, err
}
if _, err := c.Conn.Write(w.Bytes()); err != nil {
c.pool.Put(t)
return 0, err
}
return ctx.Len(), nil
c.pool.Put(t)
return len(b), nil
}
func (c *MasqueradeConn) Read(b []byte) (n int, err error) {
@@ -126,13 +126,13 @@ func (c *MasqueradeUDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr
return n, oobn, flags, addr, err
}
r := bytes.NewBuffer(b[:n])
ctx := &readContext{
r := newSliceReader(b[:n])
ctx := readContext{
FlexBuffer: WrapFlexBuffer(b),
BufferPool: c.pool,
}
if err = c.rulesIn.Read(r, ctx); err != nil {
if err = c.rulesIn.Read(&r, &ctx); err != nil {
return 0, oobn, flags, addr, err
}
@@ -141,19 +141,23 @@ func (c *MasqueradeUDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr
func (c *MasqueradeUDPConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) {
t := c.pool.Get()
defer c.pool.Put(t)
w := bytes.NewBuffer(t[:0])
ctx := &writeContext{
w := newSliceWriter(t)
ctx := writeContext{
FlexBuffer: WrapFlexBuffer(b),
BufferPool: c.pool,
}
if err = c.rulesOut.Write(w, ctx); err != nil {
if err = c.rulesOut.Write(&w, &ctx); err != nil {
c.pool.Put(t)
return 0, 0, err
}
return c.UDPConn.WriteMsgUDP(t[:ctx.Len()], oob, addr)
n, oobn, err = c.UDPConn.WriteMsgUDP(w.Bytes(), oob, addr)
c.pool.Put(t)
if err != nil {
return 0, oobn, err
}
return len(b), oobn, nil
}
func NewMasqueradeBatchConn(conn BatchConn, bp *sync.Pool, opts MasqueradeOpts) (c *MasqueradeBatchConn, ok bool) {
@@ -183,13 +187,13 @@ func (c *MasqueradeBatchConn) ReadBatch(ms []ipv4.Message, flags int) (n int, er
}
for i := range n {
r := bytes.NewBuffer(ms[i].Buffers[0][:ms[i].N])
ctx := &readContext{
r := newSliceReader(ms[i].Buffers[0][:ms[i].N])
ctx := readContext{
FlexBuffer: WrapFlexBuffer(ms[i].Buffers[0]),
BufferPool: c.pool,
}
if err = c.rulesIn.Read(r, ctx); err != nil {
if err = c.rulesIn.Read(&r, &ctx); err != nil {
return 0, err
}
@@ -200,22 +204,35 @@ func (c *MasqueradeBatchConn) ReadBatch(ms []ipv4.Message, flags int) (n int, er
}
func (c *MasqueradeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err error) {
var inline [128][]byte
pooled := inline[:0]
if len(ms) > len(inline) {
pooled = make([][]byte, 0, len(ms))
}
for i := range ms {
t := c.pool.Get()
defer c.pool.Put(t)
pooled = append(pooled, t)
w := bytes.NewBuffer(t[:0])
ctx := &writeContext{
w := newSliceWriter(t)
ctx := writeContext{
FlexBuffer: WrapFlexBuffer(ms[i].Buffers[0]),
BufferPool: c.pool,
}
if err = c.rulesOut.Write(w, ctx); err != nil {
if err = c.rulesOut.Write(&w, &ctx); err != nil {
for _, buf := range pooled {
c.pool.Put(buf)
}
return 0, err
}
ms[i].Buffers[0] = w.Bytes()
}
return c.BatchConn.WriteBatch(ms, flags)
n, err = c.BatchConn.WriteBatch(ms, flags)
for _, buf := range pooled {
c.pool.Put(buf)
}
return n, err
}

View File

@@ -0,0 +1,295 @@
package conceal
import (
"bytes"
"encoding/binary"
"errors"
"net"
"sync"
"syscall"
"testing"
"time"
"golang.org/x/net/ipv4"
)
// TestMasqueradeConnWriteRecordEncodesPayload checks the stream write path:
// WriteRecord must report the caller's payload length (not the encoded length)
// and must hand the encoded record to the wrapped connection.
//
// With the "<dz be 2><d>" rule the expected wire bytes are 0x00 0x03 followed
// by the three payload bytes (presumably a 2-byte big-endian size prefix plus
// the data; confirm against the rule grammar).
func TestMasqueradeConnWriteRecordEncodesPayload(t *testing.T) {
	rules := mustTestRules(t, "<dz be 2><d>")
	raw := &recordingConn{}
	pool := newTestBufferPool()
	conn, ok := NewMasqueradeConn(raw, pool, MasqueradeOpts{RulesOut: rules})
	if !ok {
		t.Fatal("expected masquerade conn")
	}
	payload := []byte{0x01, 0x02, 0x03}
	n, err := conn.WriteRecord(payload)
	if err != nil {
		t.Fatalf("WriteRecord failed: %v", err)
	}
	if n != len(payload) {
		t.Fatalf("WriteRecord n = %d, want %d", n, len(payload))
	}
	// Guard the index below so a missing (or duplicated) downstream write is
	// reported as a test failure instead of an index-out-of-range panic.
	if len(raw.writes) != 1 {
		t.Fatalf("write count = %d, want 1", len(raw.writes))
	}
	want := []byte{0x00, 0x03, 0x01, 0x02, 0x03}
	if !bytes.Equal(raw.writes[0], want) {
		t.Fatalf("WriteRecord bytes = %x, want %x", raw.writes[0], want)
	}
}
// TestMasqueradeUDPConnWriteMsgUDPEncodesPayload checks the datagram write
// path: WriteMsgUDP must report the caller's payload length, pass through the
// oobn value returned by the wrapped connection, and emit exactly one encoded
// datagram downstream.
func TestMasqueradeUDPConnWriteMsgUDPEncodesPayload(t *testing.T) {
	outRules := mustTestRules(t, "<dz be 2><d>")
	sink := &recordingUDPConn{}
	bufPool := newTestBufferPool()

	masq, ok := NewMasqueradeUDPConn(sink, bufPool, MasqueradeOpts{RulesOut: outRules})
	if !ok {
		t.Fatal("expected masquerade udp conn")
	}

	data := []byte{0x01, 0x02, 0x03}
	dst := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 51820}
	n, oobn, err := masq.WriteMsgUDP(data, []byte{0xaa}, dst)
	switch {
	case err != nil:
		t.Fatalf("WriteMsgUDP failed: %v", err)
	case n != len(data):
		t.Fatalf("WriteMsgUDP n = %d, want %d", n, len(data))
	case oobn != 1:
		// recordingUDPConn.WriteMsgUDP always reports oobn == 1.
		t.Fatalf("WriteMsgUDP oobn = %d, want 1", oobn)
	}

	expected := []byte{0x00, 0x03, 0x01, 0x02, 0x03}
	if count := len(sink.writes); count != 1 {
		t.Fatalf("write count = %d, want 1", count)
	}
	if got := sink.writes[0]; !bytes.Equal(got, expected) {
		t.Fatalf("WriteMsgUDP bytes = %x, want %x", got, expected)
	}
}
// TestMasqueradeBatchConnWriteBatchEncodesEachMessage checks the batch write
// path: every message in the batch must be encoded independently and the whole
// batch must reach the wrapped connection in a single WriteBatch call.
func TestMasqueradeBatchConnWriteBatchEncodesEachMessage(t *testing.T) {
	rules := mustTestRules(t, "<dz be 2><d>")
	raw := &recordingBatchConn{}
	pool := newTestBufferPool()
	conn, ok := NewMasqueradeBatchConn(raw, pool, MasqueradeOpts{RulesOut: rules})
	if !ok {
		t.Fatal("expected masquerade batch conn")
	}
	msgs := []ipv4.Message{
		{Buffers: net.Buffers{[]byte{0x01, 0x02}}},
		{Buffers: net.Buffers{[]byte{0x03}}},
	}
	n, err := conn.WriteBatch(msgs, 0)
	if err != nil {
		t.Fatalf("WriteBatch failed: %v", err)
	}
	if n != len(msgs) {
		t.Fatalf("WriteBatch n = %d, want %d", n, len(msgs))
	}
	if len(raw.batches) != 1 {
		t.Fatalf("batch count = %d, want 1", len(raw.batches))
	}
	want := [][]byte{
		{0x00, 0x02, 0x01, 0x02},
		{0x00, 0x01, 0x03},
	}
	// Compare lengths first: without this a truncated batch would pass the
	// loop silently and an oversized one would panic on want[i].
	if len(raw.batches[0]) != len(want) {
		t.Fatalf("batch len = %d, want %d", len(raw.batches[0]), len(want))
	}
	for i, got := range raw.batches[0] {
		if !bytes.Equal(got.data, want[i]) {
			t.Fatalf("batch msg %d = %x, want %x", i, got.data, want[i])
		}
	}
}
// TestPreludeBatchConnWriteBatchEmitsPreludeBeforeInitiation checks that a
// batch containing an initiation message causes a prelude batch to be written
// first, followed by the caller's batch with its payload unchanged.
//
// The same recording connection is passed for both connection arguments of
// NewPreludeBatchConn, so prelude and main writes are observed in order in
// raw.batches. With one rule set and Jc == 1 the test expects a prelude batch
// of two messages: the rule output (0xaa 0xbb) and one 3-byte junk message
// (Jmin == Jmax == 3).
func TestPreludeBatchConnWriteBatchEmitsPreludeBeforeInitiation(t *testing.T) {
	raw := &recordingBatchConn{}
	pool := newTestBufferPool()
	msgsPool := newTestMsgsPool()
	rules := mustTestRules(t, "<b 0xaabb>")
	conn, ok := NewPreludeBatchConn(raw, raw, pool, msgsPool, nil, PreludeOpts{
		Jc:   1,
		Jmin: 3,
		Jmax: 3,
		RulesArr: [5]Rules{
			rules,
		},
	})
	if !ok {
		t.Fatal("expected prelude batch conn")
	}
	// Mark the payload as an initiation message via its leading little-endian
	// type word; the remaining bytes stay zero.
	initiation := make([]byte, 8)
	binary.LittleEndian.PutUint32(initiation[:4], WireguardMsgInitiationType)
	msgs := []ipv4.Message{
		{
			Buffers: net.Buffers{initiation},
			OOB:     []byte{0x44},
			Addr:    &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 51820},
		},
	}
	n, err := conn.WriteBatch(msgs, 0)
	if err != nil {
		t.Fatalf("WriteBatch failed: %v", err)
	}
	if n != len(msgs) {
		t.Fatalf("WriteBatch n = %d, want %d", n, len(msgs))
	}
	if len(raw.batches) != 2 {
		t.Fatalf("batch count = %d, want 2", len(raw.batches))
	}
	preludeBatch := raw.batches[0]
	if len(preludeBatch) != 2 {
		t.Fatalf("prelude batch len = %d, want 2", len(preludeBatch))
	}
	if !bytes.Equal(preludeBatch[0].data, []byte{0xaa, 0xbb}) {
		t.Fatalf("prelude decoy = %x, want aabb", preludeBatch[0].data)
	}
	if len(preludeBatch[1].data) != 3 {
		t.Fatalf("junk len = %d, want 3", len(preludeBatch[1].data))
	}
	// Guard the index below so an empty or padded main batch fails the test
	// cleanly instead of panicking or going unnoticed.
	mainBatch := raw.batches[1]
	if len(mainBatch) != len(msgs) {
		t.Fatalf("main batch len = %d, want %d", len(mainBatch), len(msgs))
	}
	if !bytes.Equal(mainBatch[0].data, initiation) {
		t.Fatal("main batch payload changed")
	}
}
// mustTestRules parses spec with ParseRules and returns the result. A parse
// error aborts the calling test; t.Helper attributes the failure to the
// caller's line.
func mustTestRules(t *testing.T, spec string) Rules {
	t.Helper()
	parsed, parseErr := ParseRules(spec)
	if parseErr == nil {
		return parsed
	}
	t.Fatalf("ParseRules(%q): %v", spec, parseErr)
	return parsed
}
// newTestBufferPool returns a fresh sync.Pool whose New function hands out
// 256-byte scratch slices for the connections under test.
func newTestBufferPool() *sync.Pool {
	pool := new(sync.Pool)
	pool.New = func() any {
		return make([]byte, 256)
	}
	return pool
}
// newTestMsgsPool returns a sync.Pool that produces pointers to slices of
// eight ipv4.Message values, each pre-populated with a one-element Buffers
// list so callers can assign Buffers[0] directly.
func newTestMsgsPool() *sync.Pool {
	const batchSize = 8
	build := func() any {
		batch := make([]ipv4.Message, batchSize)
		for idx := range batch {
			batch[idx].Buffers = make(net.Buffers, 1)
		}
		return &batch
	}
	return &sync.Pool{New: build}
}
// recordingConn is a test double providing the net.Conn method set. It keeps a
// private copy of every buffer passed to Write; all other operations are
// no-ops or report "not implemented".
type recordingConn struct {
	// writes holds one copied slice per Write call, in call order.
	writes [][]byte
}

// Read is unsupported and always fails.
func (rc *recordingConn) Read(_ []byte) (int, error) {
	return 0, errors.New("not implemented")
}

// Write stores a copy of b and reports the whole buffer as written. Copying
// matters because callers may reuse or recycle b after Write returns.
func (rc *recordingConn) Write(b []byte) (int, error) {
	var snapshot []byte
	snapshot = append(snapshot, b...)
	rc.writes = append(rc.writes, snapshot)
	return len(b), nil
}

// Close does nothing and reports success.
func (rc *recordingConn) Close() error {
	return nil
}

// LocalAddr returns an empty TCP address.
func (rc *recordingConn) LocalAddr() net.Addr {
	return &net.TCPAddr{}
}

// RemoteAddr returns an empty TCP address.
func (rc *recordingConn) RemoteAddr() net.Addr {
	return &net.TCPAddr{}
}

// SetDeadline ignores the deadline and reports success.
func (rc *recordingConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline and reports success.
func (rc *recordingConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline and reports success.
func (rc *recordingConn) SetWriteDeadline(time.Time) error {
	return nil
}
// recordingUDPConn is a UDP-flavoured test double. Only WriteMsgUDP does real
// work: it keeps a private copy of each datagram. Every other I/O method
// reports "not implemented", and the deadline/close methods are no-ops.
type recordingUDPConn struct {
	// writes holds one copied datagram per WriteMsgUDP call, in call order.
	writes [][]byte
}

// ReadFrom is unsupported and always fails.
func (ru *recordingUDPConn) ReadFrom([]byte) (int, net.Addr, error) {
	return 0, nil, errors.New("not implemented")
}

// WriteTo is unsupported and always fails.
func (ru *recordingUDPConn) WriteTo([]byte, net.Addr) (int, error) {
	return 0, errors.New("not implemented")
}

// Close does nothing and reports success.
func (ru *recordingUDPConn) Close() error {
	return nil
}

// LocalAddr returns an empty UDP address.
func (ru *recordingUDPConn) LocalAddr() net.Addr {
	return &net.UDPAddr{}
}

// SetDeadline ignores the deadline and reports success.
func (ru *recordingUDPConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline and reports success.
func (ru *recordingUDPConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline and reports success.
func (ru *recordingUDPConn) SetWriteDeadline(time.Time) error {
	return nil
}

// ReadMsgUDP is unsupported and always fails.
func (ru *recordingUDPConn) ReadMsgUDP(_, _ []byte) (int, int, int, *net.UDPAddr, error) {
	return 0, 0, 0, nil, errors.New("not implemented")
}

// WriteMsgUDP stores a copy of b and reports len(b) payload bytes plus a fixed
// oobn of 1, regardless of the out-of-band data and address supplied.
func (ru *recordingUDPConn) WriteMsgUDP(b, _ []byte, _ *net.UDPAddr) (int, int, error) {
	var snapshot []byte
	snapshot = append(snapshot, b...)
	ru.writes = append(ru.writes, snapshot)
	return len(b), 1, nil
}

// SyscallConn is unsupported and always fails.
func (ru *recordingUDPConn) SyscallConn() (syscall.RawConn, error) {
	return nil, errors.New("not implemented")
}
// recordedBatchMessage is the captured payload of one message in a batch.
type recordedBatchMessage struct {
	// data is a private copy of the message's first buffer.
	data []byte
}

// recordingBatchConn is a batch-I/O test double that remembers every batch
// passed to WriteBatch.
type recordingBatchConn struct {
	// batches holds one entry per WriteBatch call, in call order.
	batches [][]recordedBatchMessage
}

// ReadBatch is unsupported and always fails.
func (rb *recordingBatchConn) ReadBatch([]ipv4.Message, int) (int, error) {
	return 0, errors.New("not implemented")
}

// WriteBatch copies Buffers[0] of every message into a new recorded batch and
// reports all messages as sent. flags is ignored. Copying matters because the
// caller may recycle the underlying buffers after the call returns.
func (rb *recordingBatchConn) WriteBatch(ms []ipv4.Message, flags int) (int, error) {
	snapshot := make([]recordedBatchMessage, 0, len(ms))
	for idx := range ms {
		var payload []byte
		payload = append(payload, ms[idx].Buffers[0]...)
		snapshot = append(snapshot, recordedBatchMessage{data: payload})
	}
	rb.batches = append(rb.batches, snapshot)
	return len(ms), nil
}

View File

@@ -1,7 +1,6 @@
package conceal
import (
"bytes"
"crypto/rand"
"encoding/binary"
"math/big"
@@ -103,35 +102,39 @@ func (c *PreludeUDPConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn
}
if isInit {
b := c.pool.Get()
defer c.pool.Put(b)
ctx := &writeContext{
buf := c.pool.Get()
ctx := writeContext{
FlexBuffer: WrapFlexBuffer(nil),
BufferPool: c.pool,
}
w := newSliceWriter(buf)
for _, rules := range c.rulesArr {
if rules == nil {
continue
}
w := bytes.NewBuffer(b[:0])
if err = rules.Write(w, ctx); err != nil {
w.Reset(buf)
if err = rules.Write(&w, &ctx); err != nil {
c.pool.Put(buf)
return 0, 0, err
}
if _, _, err = c.origin.WriteMsgUDP(w.Bytes(), oob, addr); err != nil {
c.pool.Put(buf)
return 0, 0, err
}
}
for range c.junkCount {
junk := c.junkGen.generate(b)
junk := c.junkGen.generate(buf)
if _, _, err = c.origin.WriteMsgUDP(junk, oob, addr); err != nil {
c.pool.Put(buf)
return 0, 0, err
}
}
c.pool.Put(buf)
}
return c.UDPConn.WriteMsgUDP(b, oob, addr)
@@ -194,28 +197,30 @@ func (c *PreludeConn) Write(b []byte) (n int, err error) {
func (c *PreludeConn) writePreludeRecords() (err error) {
buf := c.pool.Get()
defer c.pool.Put(buf)
ctx := &writeContext{
ctx := writeContext{
FlexBuffer: WrapFlexBuffer(nil),
BufferPool: c.pool,
}
w := newSliceWriter(buf)
for _, rules := range c.rulesArr {
if rules == nil {
continue
}
w := bytes.NewBuffer(buf[:0])
if err = rules.Write(w, ctx); err != nil {
w.Reset(buf)
if err = rules.Write(&w, &ctx); err != nil {
c.pool.Put(buf)
return err
}
if _, err = c.StreamRecordConn.WriteRecord(w.Bytes()); err != nil {
c.pool.Put(buf)
return err
}
}
c.pool.Put(buf)
return nil
}
@@ -278,13 +283,25 @@ func (c *PreludeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err
}
if initMsg != nil {
ctx := &writeContext{
ctx := writeContext{
FlexBuffer: WrapFlexBuffer(nil),
BufferPool: c.bufPool,
}
msgs := c.msgsPool.Get().(*[]ipv4.Message)
defer c.msgsPool.Put(msgs)
count := c.junkCount
for _, rules := range c.rulesArr {
if rules != nil {
count++
}
}
var inline [32][]byte
pooled := inline[:0]
if count > len(inline) {
pooled = make([][]byte, 0, count)
}
i := 0
for _, rules := range c.rulesArr {
@@ -293,10 +310,14 @@ func (c *PreludeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err
}
buf := c.bufPool.Get()
defer c.bufPool.Put(buf)
pooled = append(pooled, buf)
w := bytes.NewBuffer(buf[:0])
if err = rules.Write(w, ctx); err != nil {
w := newSliceWriter(buf)
if err = rules.Write(&w, &ctx); err != nil {
for _, pooledBuf := range pooled {
c.bufPool.Put(pooledBuf)
}
c.msgsPool.Put(msgs)
return 0, err
}
@@ -308,7 +329,7 @@ func (c *PreludeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err
for range c.junkCount {
buf := c.bufPool.Get()
defer c.bufPool.Put(buf)
pooled = append(pooled, buf)
(*msgs)[i].Buffers[0] = c.junkGen.generate(buf)
(*msgs)[i].OOB = initMsg.OOB
@@ -321,6 +342,10 @@ func (c *PreludeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err
m := (*msgs)[start:i]
n, err = c.origin.WriteBatch(m, flags)
if err != nil {
for _, pooledBuf := range pooled {
c.bufPool.Put(pooledBuf)
}
c.msgsPool.Put(msgs)
return 0, err
}
if n == len(m) {
@@ -328,6 +353,11 @@ func (c *PreludeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err
}
start += n
}
for _, pooledBuf := range pooled {
c.bufPool.Put(pooledBuf)
}
c.msgsPool.Put(msgs)
}
return c.BatchConn.WriteBatch(ms, flags)

View File

@@ -19,13 +19,13 @@ var (
)
type readContext struct {
*FlexBuffer
FlexBuffer
*BufferPool
nextDataSize int
}
type writeContext struct {
*FlexBuffer
FlexBuffer
*BufferPool
}

42
conceal/slice_buffer.go Normal file
View File

@@ -0,0 +1,42 @@
package conceal
import "io"
// sliceReader is a minimal, allocation-free io.Reader over a byte slice. It
// is created by value and used through a pointer so the read position (the
// shrinking buf slice) persists between calls.
type sliceReader struct {
	// buf is the unread remainder of the input.
	buf []byte
}

// newSliceReader returns a reader positioned at the start of buf. The slice is
// not copied; the reader aliases the caller's memory.
func newSliceReader(buf []byte) sliceReader {
	var r sliceReader
	r.buf = buf
	return r
}

// Read copies as many unread bytes as fit into p and advances past them. Once
// the input is exhausted it returns 0 and io.EOF.
func (r *sliceReader) Read(p []byte) (n int, err error) {
	remaining := r.buf
	if len(remaining) > 0 {
		n = copy(p, remaining)
		r.buf = remaining[n:]
		return n, nil
	}
	return 0, io.EOF
}
// sliceWriter is a minimal io.Writer that appends into a caller-supplied byte
// slice. It is created by value and used through a pointer so the growing buf
// slice persists between calls.
type sliceWriter struct {
	// buf holds the bytes written so far; it starts as a zero-length view of
	// the backing slice handed to newSliceWriter or Reset.
	buf []byte
}

// newSliceWriter returns a writer that starts empty and reuses buf's backing
// array for its output.
func newSliceWriter(buf []byte) sliceWriter {
	empty := buf[:0]
	return sliceWriter{buf: empty}
}

// Reset discards anything written so far and retargets the writer at buf,
// again starting from length zero.
func (w *sliceWriter) Reset(buf []byte) {
	empty := buf[:0]
	w.buf = empty
}

// Write appends p to the output and never fails. If the backing array is too
// small, append allocates a larger one and the output no longer aliases the
// original slice.
func (w *sliceWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	w.buf = append(w.buf, p...)
	return n, nil
}

// Bytes returns the bytes written so far. The result is the internal slice,
// not a copy, so it is only valid until the next Write or Reset.
func (w *sliceWriter) Bytes() []byte {
	return w.buf
}