chore: use buffer pool in ObfuscatedConn instead of the original one buf

This commit is contained in:
Yaroslav Gurov
2025-12-20 08:52:36 +00:00
parent bb8270bcfa
commit fecfe69b7b
5 changed files with 115 additions and 81 deletions

View File

@@ -1,5 +1,7 @@
package conceal
import "sync"
func NewFlexBuffer(buf []byte) *flexBuffer {
return &flexBuffer{
buf: buf,
@@ -7,25 +9,31 @@ func NewFlexBuffer(buf []byte) *flexBuffer {
}
type flexBuffer struct {
buf []byte
}
func (b *flexBuffer) TempBuffer(size int) []byte {
if size > cap(b.buf)-len(b.buf) {
return nil
}
return b.buf[len(b.buf) : len(b.buf)+size]
buf []byte
offset int
len int
}
func (b *flexBuffer) PushTail(size int) []byte {
oldLen := len(b.buf)
newLen := oldLen + size
if newLen > cap(b.buf) {
newLen := b.len + size
if b.offset+newLen > len(b.buf) {
return nil
}
b.buf = b.buf[:newLen]
return b.buf[oldLen:]
oldLen := b.len
b.len = newLen
return b.buf[b.offset+oldLen : b.offset+newLen]
}
func (b *flexBuffer) PullTail(size int) []byte {
newLen := b.len - size
if newLen < 0 {
return nil
}
oldLen := b.len
b.len = newLen
return b.buf[b.offset+newLen : b.offset+oldLen]
}
func (b *flexBuffer) PullHead(size int) []byte {
@@ -33,15 +41,29 @@ func (b *flexBuffer) PullHead(size int) []byte {
size = len(b.buf)
}
if size > len(b.buf) {
newOffset := b.offset + size
if newOffset+b.len > len(b.buf) {
return nil
}
pulled := b.buf[:size]
b.buf = b.buf[size:]
return pulled
oldOffset := b.offset
b.offset = newOffset
return b.buf[oldOffset+b.len : newOffset+b.len]
}
func (b *flexBuffer) Cap() int {
return len(b.buf)
}
func (b *flexBuffer) Len() int {
return len(b.buf)
return b.len
}
type BufferPool struct {
sync.Pool
}
func (p *BufferPool) GetBuffer() []byte {
return p.Get().([]byte)
}

View File

@@ -2,23 +2,34 @@ package conceal
import (
"net"
"sync"
)
type ObfuscatedConn struct {
net.Conn
obfs Obfs
bufs BufferPool
}
func NewObfuscatedConn(conn net.Conn, obfs Obfs) *ObfuscatedConn {
return &ObfuscatedConn{
Conn: conn,
obfs: obfs,
bufs: BufferPool{
Pool: sync.Pool{
New: func() any {
// FIXME: put reasonable bufsize here
return make([]byte, 2048)
},
},
},
}
}
func (c *ObfuscatedConn) Read(b []byte) (n int, err error) {
ctx := &readContext{
flexBuffer: NewFlexBuffer(b[:0]),
flexBuffer: NewFlexBuffer(b),
tmpPool: &c.bufs,
}
for _, obf := range c.obfs {
if err := obf.Read(c.Conn, ctx); err != nil {
@@ -31,6 +42,7 @@ func (c *ObfuscatedConn) Read(b []byte) (n int, err error) {
func (c *ObfuscatedConn) Write(b []byte) (n int, err error) {
ctx := &writeContext{
flexBuffer: NewFlexBuffer(b),
tmpPool: &c.bufs,
}
for _, obf := range c.obfs {
if err := obf.Write(c.Conn, ctx); err != nil {

View File

@@ -5,6 +5,7 @@ import (
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"unicode"
)
@@ -15,17 +16,16 @@ var (
type readContext struct {
*flexBuffer
tmpPool *BufferPool
nextDataSize int
}
func (o *bytesObf) Read(reader io.Reader, ctx *readContext) error {
buf := ctx.TempBuffer(len(o.data))
if buf == nil {
return io.ErrShortBuffer
}
tmp := ctx.tmpPool.GetBuffer()
defer ctx.tmpPool.Put(tmp)
_, err := io.ReadFull(reader, buf)
if err != nil {
buf := tmp[:len(o.data)]
if _, err := io.ReadFull(reader, buf); err != nil {
return err
}
@@ -47,13 +47,11 @@ func (o *dataObf) Read(reader io.Reader, ctx *readContext) error {
}
func (o *dataSizeObf) Read(reader io.Reader, ctx *readContext) error {
buf := ctx.TempBuffer(o.length)
if buf == nil {
return io.ErrShortBuffer
}
tmp := ctx.tmpPool.GetBuffer()
defer ctx.tmpPool.Put(tmp)
_, err := io.ReadFull(reader, buf)
if err != nil {
buf := tmp[:o.length]
if _, err := io.ReadFull(reader, buf); err != nil {
return err
}
@@ -68,43 +66,49 @@ func (o *dataSizeObf) Read(reader io.Reader, ctx *readContext) error {
}
func (o *dataStringObf) Read(reader io.Reader, ctx *readContext) error {
tmp := ctx.tmpPool.GetBuffer()
defer ctx.tmpPool.Put(tmp)
base64len := base64.RawStdEncoding.EncodedLen(ctx.nextDataSize)
buf := tmp[:base64len]
if _, err := io.ReadFull(reader, buf); err != nil {
return err
}
data := ctx.PushTail(ctx.nextDataSize)
if data == nil {
return io.ErrShortBuffer
}
base64len := base64.RawStdEncoding.EncodedLen(ctx.nextDataSize)
buf := ctx.TempBuffer(base64len)
if buf == nil {
return io.ErrShortBuffer
if _, err := base64.RawStdEncoding.Decode(data, buf); err != nil {
// return buf in case of error
ctx.PullTail(len(data))
return fmt.Errorf("failed to decode base64: %w", err)
}
return nil
}
func (o *randObf) Read(reader io.Reader, ctx *readContext) error {
tmp := ctx.tmpPool.GetBuffer()
defer ctx.tmpPool.Put(tmp)
buf := tmp[:o.length]
if _, err := io.ReadFull(reader, buf); err != nil {
return err
}
_, err := base64.RawStdEncoding.Decode(data, buf)
return err
}
func (o *randObf) Read(reader io.Reader, ctx *readContext) error {
buf := ctx.TempBuffer(o.length)
if buf == nil {
return io.ErrShortBuffer
}
_, err := io.ReadFull(reader, buf)
return err
// I guess, there is no way to validate randomness
// so just return nil here like everything is fine
return nil
}
func (o *randCharObf) Read(reader io.Reader, ctx *readContext) error {
buf := ctx.TempBuffer(o.length)
if buf == nil {
return io.ErrShortBuffer
}
tmp := ctx.tmpPool.GetBuffer()
defer ctx.tmpPool.Put(tmp)
_, err := io.ReadFull(reader, buf)
if err != nil {
buf := tmp[:o.length]
if _, err := io.ReadFull(reader, buf); err != nil {
return err
}
@@ -118,13 +122,11 @@ func (o *randCharObf) Read(reader io.Reader, ctx *readContext) error {
}
func (o *randDigitObf) Read(reader io.Reader, ctx *readContext) error {
buf := ctx.TempBuffer(o.length)
if buf == nil {
return io.ErrShortBuffer
}
tmp := ctx.tmpPool.GetBuffer()
defer ctx.tmpPool.Put(tmp)
_, err := io.ReadFull(reader, buf)
if err != nil {
buf := tmp[:o.length]
if _, err := io.ReadFull(reader, buf); err != nil {
return err
}

View File

@@ -10,6 +10,7 @@ import (
type writeContext struct {
*flexBuffer
tmpPool *BufferPool
}
func (o *bytesObf) Write(writer io.Writer, ctx *writeContext) error {
@@ -28,12 +29,12 @@ func (o *dataObf) Write(writer io.Writer, ctx *writeContext) error {
}
func (o *dataSizeObf) Write(writer io.Writer, ctx *writeContext) error {
buf := ctx.TempBuffer(o.length)
if buf == nil {
return io.ErrShortBuffer
}
tmp := ctx.tmpPool.GetBuffer()
defer ctx.tmpPool.Put(tmp)
size := uint32(ctx.Len())
buf := tmp[:o.length]
size := uint32(ctx.Cap())
for i := o.length - 1; i >= 0; i-- {
buf[i] = byte(size & 0xFF)
size >>= 8
@@ -49,11 +50,11 @@ func (o *dataStringObf) Write(writer io.Writer, ctx *writeContext) error {
return io.ErrShortBuffer
}
tmp := ctx.tmpPool.GetBuffer()
defer ctx.tmpPool.Put(tmp)
base64len := base64.RawStdEncoding.EncodedLen(len(data))
buf := ctx.TempBuffer(base64len)
if buf == nil {
return io.ErrShortBuffer
}
buf := tmp[:base64len]
base64.RawStdEncoding.Encode(buf, data)
@@ -62,11 +63,10 @@ func (o *dataStringObf) Write(writer io.Writer, ctx *writeContext) error {
}
func (o *randObf) Write(writer io.Writer, ctx *writeContext) error {
buf := ctx.TempBuffer(o.length)
if buf == nil {
return io.ErrShortBuffer
}
tmp := ctx.tmpPool.GetBuffer()
defer ctx.tmpPool.Put(tmp)
buf := tmp[:o.length]
rand.Read(buf)
_, err := writer.Write(buf)
@@ -76,11 +76,10 @@ func (o *randObf) Write(writer io.Writer, ctx *writeContext) error {
const chars52 = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func (o *randCharObf) Write(writer io.Writer, ctx *writeContext) error {
buf := ctx.TempBuffer(o.length)
if buf == nil {
return io.ErrShortBuffer
}
tmp := ctx.tmpPool.GetBuffer()
defer ctx.tmpPool.Put(tmp)
buf := tmp[:o.length]
rand.Read(buf)
for i := range buf {
buf[i] = chars52[buf[i]%52]
@@ -93,11 +92,10 @@ func (o *randCharObf) Write(writer io.Writer, ctx *writeContext) error {
const digits10 = "0123456789"
func (o *randDigitObf) Write(writer io.Writer, ctx *writeContext) error {
buf := ctx.TempBuffer(o.length)
if buf == nil {
return io.ErrShortBuffer
}
tmp := ctx.tmpPool.GetBuffer()
defer ctx.tmpPool.Put(tmp)
buf := tmp[:o.length]
rand.Read(buf)
for i := range buf {
buf[i] = digits10[buf[i]%10]

View File

@@ -149,7 +149,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
sendBuffer = append(sendBuffer, buf)
}
var buf [12 + MessageInitiationSize]byte
var buf [MessageInitiationSize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()