mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2026-05-17 08:15:49 +03:00
perf(conceal): optimize record and prelude hot paths
Replace hot-path bytes.NewBuffer usage with slice-backed readers and writers, and build FlexBuffer-backed read/write contexts by value. Remove defer from the targeted per-message and per-batch write paths, return pooled buffers explicitly after downstream writes complete, and keep TCP/UDP prelude behavior unchanged. Add focused regression tests for stream, UDP, and batch masquerade and prelude writes, and fix write-side bookkeeping to emit the actual encoded bytes.
This commit is contained in:
committed by
Yaroslav Gurov
parent
c65f05a009
commit
f4cdedf40a
@@ -178,15 +178,15 @@ func benchmarkEncodeFramedRecord(opts FramedOpts, payload []byte) []byte {
|
||||
|
||||
func benchmarkEncodeMasqueradeRecord(rules Rules, payload []byte) []byte {
|
||||
pool := benchmarkNewBufferPool()
|
||||
ctx := &writeContext{
|
||||
ctx := writeContext{
|
||||
FlexBuffer: WrapFlexBuffer(payload),
|
||||
BufferPool: WrapBufferPool(pool),
|
||||
}
|
||||
tmp := pool.Get().([]byte)
|
||||
defer pool.Put(tmp)
|
||||
|
||||
w := bytes.NewBuffer(tmp[:0])
|
||||
if err := rules.Write(w, ctx); err != nil {
|
||||
w := newSliceWriter(tmp)
|
||||
if err := rules.Write(&w, &ctx); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return append([]byte(nil), w.Bytes()...)
|
||||
|
||||
@@ -2,8 +2,8 @@ package conceal
|
||||
|
||||
import "sync"
|
||||
|
||||
func WrapFlexBuffer(buf []byte) *FlexBuffer {
|
||||
return &FlexBuffer{
|
||||
func WrapFlexBuffer(buf []byte) FlexBuffer {
|
||||
return FlexBuffer{
|
||||
buf: buf,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package conceal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
@@ -48,12 +47,12 @@ func (c *MasqueradeConn) ReadRecord(b []byte) (n int, err error) {
|
||||
return 0, ErrNoReadRecord
|
||||
}
|
||||
|
||||
ctx := &readContext{
|
||||
ctx := readContext{
|
||||
FlexBuffer: WrapFlexBuffer(b),
|
||||
BufferPool: c.pool,
|
||||
}
|
||||
|
||||
if err := c.rulesIn.Read(c.Conn, ctx); err != nil {
|
||||
if err := c.rulesIn.Read(c.Conn, &ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -65,25 +64,26 @@ func (c *MasqueradeConn) WriteRecord(b []byte) (n int, err error) {
|
||||
return 0, ErrNoWriteRecord
|
||||
}
|
||||
|
||||
ctx := &writeContext{
|
||||
ctx := writeContext{
|
||||
FlexBuffer: WrapFlexBuffer(b),
|
||||
BufferPool: c.pool,
|
||||
}
|
||||
|
||||
t := c.pool.Get()
|
||||
defer c.pool.Put(t)
|
||||
w := newSliceWriter(t)
|
||||
|
||||
w := bytes.NewBuffer(t[:0])
|
||||
|
||||
if err := c.rulesOut.Write(w, ctx); err != nil {
|
||||
if err := c.rulesOut.Write(&w, &ctx); err != nil {
|
||||
c.pool.Put(t)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if _, err := c.Conn.Write(w.Bytes()); err != nil {
|
||||
c.pool.Put(t)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return ctx.Len(), nil
|
||||
c.pool.Put(t)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *MasqueradeConn) Read(b []byte) (n int, err error) {
|
||||
@@ -126,13 +126,13 @@ func (c *MasqueradeUDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr
|
||||
return n, oobn, flags, addr, err
|
||||
}
|
||||
|
||||
r := bytes.NewBuffer(b[:n])
|
||||
ctx := &readContext{
|
||||
r := newSliceReader(b[:n])
|
||||
ctx := readContext{
|
||||
FlexBuffer: WrapFlexBuffer(b),
|
||||
BufferPool: c.pool,
|
||||
}
|
||||
|
||||
if err = c.rulesIn.Read(r, ctx); err != nil {
|
||||
if err = c.rulesIn.Read(&r, &ctx); err != nil {
|
||||
return 0, oobn, flags, addr, err
|
||||
}
|
||||
|
||||
@@ -141,19 +141,23 @@ func (c *MasqueradeUDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr
|
||||
|
||||
func (c *MasqueradeUDPConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) {
|
||||
t := c.pool.Get()
|
||||
defer c.pool.Put(t)
|
||||
|
||||
w := bytes.NewBuffer(t[:0])
|
||||
ctx := &writeContext{
|
||||
w := newSliceWriter(t)
|
||||
ctx := writeContext{
|
||||
FlexBuffer: WrapFlexBuffer(b),
|
||||
BufferPool: c.pool,
|
||||
}
|
||||
|
||||
if err = c.rulesOut.Write(w, ctx); err != nil {
|
||||
if err = c.rulesOut.Write(&w, &ctx); err != nil {
|
||||
c.pool.Put(t)
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return c.UDPConn.WriteMsgUDP(t[:ctx.Len()], oob, addr)
|
||||
n, oobn, err = c.UDPConn.WriteMsgUDP(w.Bytes(), oob, addr)
|
||||
c.pool.Put(t)
|
||||
if err != nil {
|
||||
return 0, oobn, err
|
||||
}
|
||||
return len(b), oobn, nil
|
||||
}
|
||||
|
||||
func NewMasqueradeBatchConn(conn BatchConn, bp *sync.Pool, opts MasqueradeOpts) (c *MasqueradeBatchConn, ok bool) {
|
||||
@@ -183,13 +187,13 @@ func (c *MasqueradeBatchConn) ReadBatch(ms []ipv4.Message, flags int) (n int, er
|
||||
}
|
||||
|
||||
for i := range n {
|
||||
r := bytes.NewBuffer(ms[i].Buffers[0][:ms[i].N])
|
||||
ctx := &readContext{
|
||||
r := newSliceReader(ms[i].Buffers[0][:ms[i].N])
|
||||
ctx := readContext{
|
||||
FlexBuffer: WrapFlexBuffer(ms[i].Buffers[0]),
|
||||
BufferPool: c.pool,
|
||||
}
|
||||
|
||||
if err = c.rulesIn.Read(r, ctx); err != nil {
|
||||
if err = c.rulesIn.Read(&r, &ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -200,22 +204,35 @@ func (c *MasqueradeBatchConn) ReadBatch(ms []ipv4.Message, flags int) (n int, er
|
||||
}
|
||||
|
||||
func (c *MasqueradeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err error) {
|
||||
var inline [128][]byte
|
||||
pooled := inline[:0]
|
||||
if len(ms) > len(inline) {
|
||||
pooled = make([][]byte, 0, len(ms))
|
||||
}
|
||||
|
||||
for i := range ms {
|
||||
t := c.pool.Get()
|
||||
defer c.pool.Put(t)
|
||||
pooled = append(pooled, t)
|
||||
|
||||
w := bytes.NewBuffer(t[:0])
|
||||
ctx := &writeContext{
|
||||
w := newSliceWriter(t)
|
||||
ctx := writeContext{
|
||||
FlexBuffer: WrapFlexBuffer(ms[i].Buffers[0]),
|
||||
BufferPool: c.pool,
|
||||
}
|
||||
|
||||
if err = c.rulesOut.Write(w, ctx); err != nil {
|
||||
if err = c.rulesOut.Write(&w, &ctx); err != nil {
|
||||
for _, buf := range pooled {
|
||||
c.pool.Put(buf)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
ms[i].Buffers[0] = w.Bytes()
|
||||
}
|
||||
|
||||
return c.BatchConn.WriteBatch(ms, flags)
|
||||
n, err = c.BatchConn.WriteBatch(ms, flags)
|
||||
for _, buf := range pooled {
|
||||
c.pool.Put(buf)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
295
conceal/conn_optimizations_test.go
Normal file
295
conceal/conn_optimizations_test.go
Normal file
@@ -0,0 +1,295 @@
|
||||
package conceal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
func TestMasqueradeConnWriteRecordEncodesPayload(t *testing.T) {
|
||||
rules := mustTestRules(t, "<dz be 2><d>")
|
||||
raw := &recordingConn{}
|
||||
pool := newTestBufferPool()
|
||||
|
||||
conn, ok := NewMasqueradeConn(raw, pool, MasqueradeOpts{RulesOut: rules})
|
||||
if !ok {
|
||||
t.Fatal("expected masquerade conn")
|
||||
}
|
||||
|
||||
payload := []byte{0x01, 0x02, 0x03}
|
||||
n, err := conn.WriteRecord(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteRecord failed: %v", err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
t.Fatalf("WriteRecord n = %d, want %d", n, len(payload))
|
||||
}
|
||||
|
||||
want := []byte{0x00, 0x03, 0x01, 0x02, 0x03}
|
||||
if !bytes.Equal(raw.writes[0], want) {
|
||||
t.Fatalf("WriteRecord bytes = %x, want %x", raw.writes[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMasqueradeUDPConnWriteMsgUDPEncodesPayload(t *testing.T) {
|
||||
rules := mustTestRules(t, "<dz be 2><d>")
|
||||
raw := &recordingUDPConn{}
|
||||
pool := newTestBufferPool()
|
||||
|
||||
conn, ok := NewMasqueradeUDPConn(raw, pool, MasqueradeOpts{RulesOut: rules})
|
||||
if !ok {
|
||||
t.Fatal("expected masquerade udp conn")
|
||||
}
|
||||
|
||||
payload := []byte{0x01, 0x02, 0x03}
|
||||
n, oobn, err := conn.WriteMsgUDP(payload, []byte{0xaa}, &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 51820})
|
||||
if err != nil {
|
||||
t.Fatalf("WriteMsgUDP failed: %v", err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
t.Fatalf("WriteMsgUDP n = %d, want %d", n, len(payload))
|
||||
}
|
||||
if oobn != 1 {
|
||||
t.Fatalf("WriteMsgUDP oobn = %d, want 1", oobn)
|
||||
}
|
||||
|
||||
want := []byte{0x00, 0x03, 0x01, 0x02, 0x03}
|
||||
if len(raw.writes) != 1 {
|
||||
t.Fatalf("write count = %d, want 1", len(raw.writes))
|
||||
}
|
||||
if !bytes.Equal(raw.writes[0], want) {
|
||||
t.Fatalf("WriteMsgUDP bytes = %x, want %x", raw.writes[0], want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMasqueradeBatchConnWriteBatchEncodesEachMessage(t *testing.T) {
|
||||
rules := mustTestRules(t, "<dz be 2><d>")
|
||||
raw := &recordingBatchConn{}
|
||||
pool := newTestBufferPool()
|
||||
|
||||
conn, ok := NewMasqueradeBatchConn(raw, pool, MasqueradeOpts{RulesOut: rules})
|
||||
if !ok {
|
||||
t.Fatal("expected masquerade batch conn")
|
||||
}
|
||||
|
||||
msgs := []ipv4.Message{
|
||||
{Buffers: net.Buffers{[]byte{0x01, 0x02}}},
|
||||
{Buffers: net.Buffers{[]byte{0x03}}},
|
||||
}
|
||||
n, err := conn.WriteBatch(msgs, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteBatch failed: %v", err)
|
||||
}
|
||||
if n != len(msgs) {
|
||||
t.Fatalf("WriteBatch n = %d, want %d", n, len(msgs))
|
||||
}
|
||||
if len(raw.batches) != 1 {
|
||||
t.Fatalf("batch count = %d, want 1", len(raw.batches))
|
||||
}
|
||||
|
||||
want := [][]byte{
|
||||
{0x00, 0x02, 0x01, 0x02},
|
||||
{0x00, 0x01, 0x03},
|
||||
}
|
||||
for i, got := range raw.batches[0] {
|
||||
if !bytes.Equal(got.data, want[i]) {
|
||||
t.Fatalf("batch msg %d = %x, want %x", i, got.data, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreludeBatchConnWriteBatchEmitsPreludeBeforeInitiation(t *testing.T) {
|
||||
raw := &recordingBatchConn{}
|
||||
pool := newTestBufferPool()
|
||||
msgsPool := newTestMsgsPool()
|
||||
rules := mustTestRules(t, "<b 0xaabb>")
|
||||
|
||||
conn, ok := NewPreludeBatchConn(raw, raw, pool, msgsPool, nil, PreludeOpts{
|
||||
Jc: 1,
|
||||
Jmin: 3,
|
||||
Jmax: 3,
|
||||
RulesArr: [5]Rules{
|
||||
rules,
|
||||
},
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("expected prelude batch conn")
|
||||
}
|
||||
|
||||
initiation := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint32(initiation[:4], WireguardMsgInitiationType)
|
||||
msgs := []ipv4.Message{
|
||||
{
|
||||
Buffers: net.Buffers{initiation},
|
||||
OOB: []byte{0x44},
|
||||
Addr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 51820},
|
||||
},
|
||||
}
|
||||
|
||||
n, err := conn.WriteBatch(msgs, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteBatch failed: %v", err)
|
||||
}
|
||||
if n != len(msgs) {
|
||||
t.Fatalf("WriteBatch n = %d, want %d", n, len(msgs))
|
||||
}
|
||||
if len(raw.batches) != 2 {
|
||||
t.Fatalf("batch count = %d, want 2", len(raw.batches))
|
||||
}
|
||||
|
||||
preludeBatch := raw.batches[0]
|
||||
if len(preludeBatch) != 2 {
|
||||
t.Fatalf("prelude batch len = %d, want 2", len(preludeBatch))
|
||||
}
|
||||
if !bytes.Equal(preludeBatch[0].data, []byte{0xaa, 0xbb}) {
|
||||
t.Fatalf("prelude decoy = %x, want aabb", preludeBatch[0].data)
|
||||
}
|
||||
if len(preludeBatch[1].data) != 3 {
|
||||
t.Fatalf("junk len = %d, want 3", len(preludeBatch[1].data))
|
||||
}
|
||||
if !bytes.Equal(raw.batches[1][0].data, initiation) {
|
||||
t.Fatalf("main batch payload changed")
|
||||
}
|
||||
}
|
||||
|
||||
func mustTestRules(t *testing.T, spec string) Rules {
|
||||
t.Helper()
|
||||
|
||||
rules, err := ParseRules(spec)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseRules(%q): %v", spec, err)
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
func newTestBufferPool() *sync.Pool {
|
||||
return &sync.Pool{
|
||||
New: func() any {
|
||||
return make([]byte, 256)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newTestMsgsPool() *sync.Pool {
|
||||
return &sync.Pool{
|
||||
New: func() any {
|
||||
msgs := make([]ipv4.Message, 8)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type recordingConn struct {
|
||||
writes [][]byte
|
||||
}
|
||||
|
||||
func (c *recordingConn) Read(_ []byte) (int, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (c *recordingConn) Write(b []byte) (int, error) {
|
||||
c.writes = append(c.writes, append([]byte(nil), b...))
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *recordingConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *recordingConn) LocalAddr() net.Addr {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
|
||||
func (c *recordingConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{}
|
||||
}
|
||||
|
||||
func (c *recordingConn) SetDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *recordingConn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *recordingConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type recordingUDPConn struct {
|
||||
writes [][]byte
|
||||
}
|
||||
|
||||
func (c *recordingUDPConn) ReadFrom([]byte) (int, net.Addr, error) {
|
||||
return 0, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (c *recordingUDPConn) WriteTo([]byte, net.Addr) (int, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (c *recordingUDPConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *recordingUDPConn) LocalAddr() net.Addr {
|
||||
return &net.UDPAddr{}
|
||||
}
|
||||
|
||||
func (c *recordingUDPConn) SetDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *recordingUDPConn) SetReadDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *recordingUDPConn) SetWriteDeadline(time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *recordingUDPConn) ReadMsgUDP(_, _ []byte) (int, int, int, *net.UDPAddr, error) {
|
||||
return 0, 0, 0, nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (c *recordingUDPConn) WriteMsgUDP(b, _ []byte, _ *net.UDPAddr) (int, int, error) {
|
||||
c.writes = append(c.writes, append([]byte(nil), b...))
|
||||
return len(b), 1, nil
|
||||
}
|
||||
|
||||
func (c *recordingUDPConn) SyscallConn() (syscall.RawConn, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
type recordedBatchMessage struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
type recordingBatchConn struct {
|
||||
batches [][]recordedBatchMessage
|
||||
}
|
||||
|
||||
func (c *recordingBatchConn) ReadBatch([]ipv4.Message, int) (int, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (c *recordingBatchConn) WriteBatch(ms []ipv4.Message, flags int) (int, error) {
|
||||
batch := make([]recordedBatchMessage, len(ms))
|
||||
for i := range ms {
|
||||
batch[i] = recordedBatchMessage{
|
||||
data: append([]byte(nil), ms[i].Buffers[0]...),
|
||||
}
|
||||
}
|
||||
c.batches = append(c.batches, batch)
|
||||
return len(ms), nil
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package conceal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"math/big"
|
||||
@@ -103,35 +102,39 @@ func (c *PreludeUDPConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn
|
||||
}
|
||||
|
||||
if isInit {
|
||||
b := c.pool.Get()
|
||||
defer c.pool.Put(b)
|
||||
|
||||
ctx := &writeContext{
|
||||
buf := c.pool.Get()
|
||||
ctx := writeContext{
|
||||
FlexBuffer: WrapFlexBuffer(nil),
|
||||
BufferPool: c.pool,
|
||||
}
|
||||
w := newSliceWriter(buf)
|
||||
|
||||
for _, rules := range c.rulesArr {
|
||||
if rules == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
w := bytes.NewBuffer(b[:0])
|
||||
if err = rules.Write(w, ctx); err != nil {
|
||||
w.Reset(buf)
|
||||
if err = rules.Write(&w, &ctx); err != nil {
|
||||
c.pool.Put(buf)
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
if _, _, err = c.origin.WriteMsgUDP(w.Bytes(), oob, addr); err != nil {
|
||||
c.pool.Put(buf)
|
||||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
for range c.junkCount {
|
||||
junk := c.junkGen.generate(b)
|
||||
junk := c.junkGen.generate(buf)
|
||||
if _, _, err = c.origin.WriteMsgUDP(junk, oob, addr); err != nil {
|
||||
c.pool.Put(buf)
|
||||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
c.pool.Put(buf)
|
||||
}
|
||||
|
||||
return c.UDPConn.WriteMsgUDP(b, oob, addr)
|
||||
@@ -194,28 +197,30 @@ func (c *PreludeConn) Write(b []byte) (n int, err error) {
|
||||
|
||||
func (c *PreludeConn) writePreludeRecords() (err error) {
|
||||
buf := c.pool.Get()
|
||||
defer c.pool.Put(buf)
|
||||
|
||||
ctx := &writeContext{
|
||||
ctx := writeContext{
|
||||
FlexBuffer: WrapFlexBuffer(nil),
|
||||
BufferPool: c.pool,
|
||||
}
|
||||
w := newSliceWriter(buf)
|
||||
|
||||
for _, rules := range c.rulesArr {
|
||||
if rules == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
w := bytes.NewBuffer(buf[:0])
|
||||
if err = rules.Write(w, ctx); err != nil {
|
||||
w.Reset(buf)
|
||||
if err = rules.Write(&w, &ctx); err != nil {
|
||||
c.pool.Put(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = c.StreamRecordConn.WriteRecord(w.Bytes()); err != nil {
|
||||
c.pool.Put(buf)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
c.pool.Put(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -278,13 +283,25 @@ func (c *PreludeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err
|
||||
}
|
||||
|
||||
if initMsg != nil {
|
||||
ctx := &writeContext{
|
||||
ctx := writeContext{
|
||||
FlexBuffer: WrapFlexBuffer(nil),
|
||||
BufferPool: c.bufPool,
|
||||
}
|
||||
|
||||
msgs := c.msgsPool.Get().(*[]ipv4.Message)
|
||||
defer c.msgsPool.Put(msgs)
|
||||
count := c.junkCount
|
||||
for _, rules := range c.rulesArr {
|
||||
if rules != nil {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
var inline [32][]byte
|
||||
pooled := inline[:0]
|
||||
if count > len(inline) {
|
||||
pooled = make([][]byte, 0, count)
|
||||
}
|
||||
|
||||
i := 0
|
||||
|
||||
for _, rules := range c.rulesArr {
|
||||
@@ -293,10 +310,14 @@ func (c *PreludeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err
|
||||
}
|
||||
|
||||
buf := c.bufPool.Get()
|
||||
defer c.bufPool.Put(buf)
|
||||
pooled = append(pooled, buf)
|
||||
|
||||
w := bytes.NewBuffer(buf[:0])
|
||||
if err = rules.Write(w, ctx); err != nil {
|
||||
w := newSliceWriter(buf)
|
||||
if err = rules.Write(&w, &ctx); err != nil {
|
||||
for _, pooledBuf := range pooled {
|
||||
c.bufPool.Put(pooledBuf)
|
||||
}
|
||||
c.msgsPool.Put(msgs)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -308,7 +329,7 @@ func (c *PreludeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err
|
||||
|
||||
for range c.junkCount {
|
||||
buf := c.bufPool.Get()
|
||||
defer c.bufPool.Put(buf)
|
||||
pooled = append(pooled, buf)
|
||||
|
||||
(*msgs)[i].Buffers[0] = c.junkGen.generate(buf)
|
||||
(*msgs)[i].OOB = initMsg.OOB
|
||||
@@ -321,6 +342,10 @@ func (c *PreludeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err
|
||||
m := (*msgs)[start:i]
|
||||
n, err = c.origin.WriteBatch(m, flags)
|
||||
if err != nil {
|
||||
for _, pooledBuf := range pooled {
|
||||
c.bufPool.Put(pooledBuf)
|
||||
}
|
||||
c.msgsPool.Put(msgs)
|
||||
return 0, err
|
||||
}
|
||||
if n == len(m) {
|
||||
@@ -328,6 +353,11 @@ func (c *PreludeBatchConn) WriteBatch(ms []ipv4.Message, flags int) (n int, err
|
||||
}
|
||||
start += n
|
||||
}
|
||||
|
||||
for _, pooledBuf := range pooled {
|
||||
c.bufPool.Put(pooledBuf)
|
||||
}
|
||||
c.msgsPool.Put(msgs)
|
||||
}
|
||||
|
||||
return c.BatchConn.WriteBatch(ms, flags)
|
||||
|
||||
@@ -19,13 +19,13 @@ var (
|
||||
)
|
||||
|
||||
type readContext struct {
|
||||
*FlexBuffer
|
||||
FlexBuffer
|
||||
*BufferPool
|
||||
nextDataSize int
|
||||
}
|
||||
|
||||
type writeContext struct {
|
||||
*FlexBuffer
|
||||
FlexBuffer
|
||||
*BufferPool
|
||||
}
|
||||
|
||||
|
||||
42
conceal/slice_buffer.go
Normal file
42
conceal/slice_buffer.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package conceal
|
||||
|
||||
import "io"
|
||||
|
||||
type sliceReader struct {
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func newSliceReader(buf []byte) sliceReader {
|
||||
return sliceReader{buf: buf}
|
||||
}
|
||||
|
||||
func (r *sliceReader) Read(p []byte) (n int, err error) {
|
||||
if len(r.buf) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = copy(p, r.buf)
|
||||
r.buf = r.buf[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
type sliceWriter struct {
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func newSliceWriter(buf []byte) sliceWriter {
|
||||
return sliceWriter{buf: buf[:0]}
|
||||
}
|
||||
|
||||
func (w *sliceWriter) Reset(buf []byte) {
|
||||
w.buf = buf[:0]
|
||||
}
|
||||
|
||||
func (w *sliceWriter) Write(p []byte) (n int, err error) {
|
||||
w.buf = append(w.buf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *sliceWriter) Bytes() []byte {
|
||||
return w.buf
|
||||
}
|
||||
Reference in New Issue
Block a user