mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2026-05-17 08:15:49 +03:00
chore: use buffer pool in ObfuscatedConn instead of the original one buf
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
package conceal
|
||||
|
||||
import "sync"
|
||||
|
||||
func NewFlexBuffer(buf []byte) *flexBuffer {
|
||||
return &flexBuffer{
|
||||
buf: buf,
|
||||
@@ -7,25 +9,31 @@ func NewFlexBuffer(buf []byte) *flexBuffer {
|
||||
}
|
||||
|
||||
type flexBuffer struct {
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func (b *flexBuffer) TempBuffer(size int) []byte {
|
||||
if size > cap(b.buf)-len(b.buf) {
|
||||
return nil
|
||||
}
|
||||
return b.buf[len(b.buf) : len(b.buf)+size]
|
||||
buf []byte
|
||||
offset int
|
||||
len int
|
||||
}
|
||||
|
||||
func (b *flexBuffer) PushTail(size int) []byte {
|
||||
oldLen := len(b.buf)
|
||||
newLen := oldLen + size
|
||||
if newLen > cap(b.buf) {
|
||||
newLen := b.len + size
|
||||
if b.offset+newLen > len(b.buf) {
|
||||
return nil
|
||||
}
|
||||
|
||||
b.buf = b.buf[:newLen]
|
||||
return b.buf[oldLen:]
|
||||
oldLen := b.len
|
||||
b.len = newLen
|
||||
return b.buf[b.offset+oldLen : b.offset+newLen]
|
||||
}
|
||||
|
||||
func (b *flexBuffer) PullTail(size int) []byte {
|
||||
newLen := b.len - size
|
||||
if newLen < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
oldLen := b.len
|
||||
b.len = newLen
|
||||
return b.buf[b.offset+newLen : b.offset+oldLen]
|
||||
}
|
||||
|
||||
func (b *flexBuffer) PullHead(size int) []byte {
|
||||
@@ -33,15 +41,29 @@ func (b *flexBuffer) PullHead(size int) []byte {
|
||||
size = len(b.buf)
|
||||
}
|
||||
|
||||
if size > len(b.buf) {
|
||||
newOffset := b.offset + size
|
||||
if newOffset+b.len > len(b.buf) {
|
||||
return nil
|
||||
}
|
||||
|
||||
pulled := b.buf[:size]
|
||||
b.buf = b.buf[size:]
|
||||
return pulled
|
||||
oldOffset := b.offset
|
||||
b.offset = newOffset
|
||||
|
||||
return b.buf[oldOffset+b.len : newOffset+b.len]
|
||||
}
|
||||
|
||||
func (b *flexBuffer) Cap() int {
|
||||
return len(b.buf)
|
||||
}
|
||||
|
||||
func (b *flexBuffer) Len() int {
|
||||
return len(b.buf)
|
||||
return b.len
|
||||
}
|
||||
|
||||
type BufferPool struct {
|
||||
sync.Pool
|
||||
}
|
||||
|
||||
func (p *BufferPool) GetBuffer() []byte {
|
||||
return p.Get().([]byte)
|
||||
}
|
||||
|
||||
@@ -2,23 +2,34 @@ package conceal
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type ObfuscatedConn struct {
|
||||
net.Conn
|
||||
obfs Obfs
|
||||
bufs BufferPool
|
||||
}
|
||||
|
||||
func NewObfuscatedConn(conn net.Conn, obfs Obfs) *ObfuscatedConn {
|
||||
return &ObfuscatedConn{
|
||||
Conn: conn,
|
||||
obfs: obfs,
|
||||
bufs: BufferPool{
|
||||
Pool: sync.Pool{
|
||||
New: func() any {
|
||||
// FIXME: put reasonable bufsize here
|
||||
return make([]byte, 2048)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ObfuscatedConn) Read(b []byte) (n int, err error) {
|
||||
ctx := &readContext{
|
||||
flexBuffer: NewFlexBuffer(b[:0]),
|
||||
flexBuffer: NewFlexBuffer(b),
|
||||
tmpPool: &c.bufs,
|
||||
}
|
||||
for _, obf := range c.obfs {
|
||||
if err := obf.Read(c.Conn, ctx); err != nil {
|
||||
@@ -31,6 +42,7 @@ func (c *ObfuscatedConn) Read(b []byte) (n int, err error) {
|
||||
func (c *ObfuscatedConn) Write(b []byte) (n int, err error) {
|
||||
ctx := &writeContext{
|
||||
flexBuffer: NewFlexBuffer(b),
|
||||
tmpPool: &c.bufs,
|
||||
}
|
||||
for _, obf := range c.obfs {
|
||||
if err := obf.Write(c.Conn, ctx); err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"unicode"
|
||||
)
|
||||
@@ -15,17 +16,16 @@ var (
|
||||
|
||||
type readContext struct {
|
||||
*flexBuffer
|
||||
tmpPool *BufferPool
|
||||
nextDataSize int
|
||||
}
|
||||
|
||||
func (o *bytesObf) Read(reader io.Reader, ctx *readContext) error {
|
||||
buf := ctx.TempBuffer(len(o.data))
|
||||
if buf == nil {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
tmp := ctx.tmpPool.GetBuffer()
|
||||
defer ctx.tmpPool.Put(tmp)
|
||||
|
||||
_, err := io.ReadFull(reader, buf)
|
||||
if err != nil {
|
||||
buf := tmp[:len(o.data)]
|
||||
if _, err := io.ReadFull(reader, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -47,13 +47,11 @@ func (o *dataObf) Read(reader io.Reader, ctx *readContext) error {
|
||||
}
|
||||
|
||||
func (o *dataSizeObf) Read(reader io.Reader, ctx *readContext) error {
|
||||
buf := ctx.TempBuffer(o.length)
|
||||
if buf == nil {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
tmp := ctx.tmpPool.GetBuffer()
|
||||
defer ctx.tmpPool.Put(tmp)
|
||||
|
||||
_, err := io.ReadFull(reader, buf)
|
||||
if err != nil {
|
||||
buf := tmp[:o.length]
|
||||
if _, err := io.ReadFull(reader, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -68,43 +66,49 @@ func (o *dataSizeObf) Read(reader io.Reader, ctx *readContext) error {
|
||||
}
|
||||
|
||||
func (o *dataStringObf) Read(reader io.Reader, ctx *readContext) error {
|
||||
tmp := ctx.tmpPool.GetBuffer()
|
||||
defer ctx.tmpPool.Put(tmp)
|
||||
|
||||
base64len := base64.RawStdEncoding.EncodedLen(ctx.nextDataSize)
|
||||
buf := tmp[:base64len]
|
||||
if _, err := io.ReadFull(reader, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := ctx.PushTail(ctx.nextDataSize)
|
||||
if data == nil {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
|
||||
base64len := base64.RawStdEncoding.EncodedLen(ctx.nextDataSize)
|
||||
buf := ctx.TempBuffer(base64len)
|
||||
if buf == nil {
|
||||
return io.ErrShortBuffer
|
||||
if _, err := base64.RawStdEncoding.Decode(data, buf); err != nil {
|
||||
// return buf in case of error
|
||||
ctx.PullTail(len(data))
|
||||
return fmt.Errorf("failed to decode base64: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *randObf) Read(reader io.Reader, ctx *readContext) error {
|
||||
tmp := ctx.tmpPool.GetBuffer()
|
||||
defer ctx.tmpPool.Put(tmp)
|
||||
|
||||
buf := tmp[:o.length]
|
||||
if _, err := io.ReadFull(reader, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := base64.RawStdEncoding.Decode(data, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (o *randObf) Read(reader io.Reader, ctx *readContext) error {
|
||||
buf := ctx.TempBuffer(o.length)
|
||||
if buf == nil {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
|
||||
_, err := io.ReadFull(reader, buf)
|
||||
return err
|
||||
// I guess, there is no way to validate randomness
|
||||
// so just return nil here like everything is fine
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *randCharObf) Read(reader io.Reader, ctx *readContext) error {
|
||||
buf := ctx.TempBuffer(o.length)
|
||||
if buf == nil {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
tmp := ctx.tmpPool.GetBuffer()
|
||||
defer ctx.tmpPool.Put(tmp)
|
||||
|
||||
_, err := io.ReadFull(reader, buf)
|
||||
if err != nil {
|
||||
buf := tmp[:o.length]
|
||||
if _, err := io.ReadFull(reader, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -118,13 +122,11 @@ func (o *randCharObf) Read(reader io.Reader, ctx *readContext) error {
|
||||
}
|
||||
|
||||
func (o *randDigitObf) Read(reader io.Reader, ctx *readContext) error {
|
||||
buf := ctx.TempBuffer(o.length)
|
||||
if buf == nil {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
tmp := ctx.tmpPool.GetBuffer()
|
||||
defer ctx.tmpPool.Put(tmp)
|
||||
|
||||
_, err := io.ReadFull(reader, buf)
|
||||
if err != nil {
|
||||
buf := tmp[:o.length]
|
||||
if _, err := io.ReadFull(reader, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
type writeContext struct {
|
||||
*flexBuffer
|
||||
tmpPool *BufferPool
|
||||
}
|
||||
|
||||
func (o *bytesObf) Write(writer io.Writer, ctx *writeContext) error {
|
||||
@@ -28,12 +29,12 @@ func (o *dataObf) Write(writer io.Writer, ctx *writeContext) error {
|
||||
}
|
||||
|
||||
func (o *dataSizeObf) Write(writer io.Writer, ctx *writeContext) error {
|
||||
buf := ctx.TempBuffer(o.length)
|
||||
if buf == nil {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
tmp := ctx.tmpPool.GetBuffer()
|
||||
defer ctx.tmpPool.Put(tmp)
|
||||
|
||||
size := uint32(ctx.Len())
|
||||
buf := tmp[:o.length]
|
||||
|
||||
size := uint32(ctx.Cap())
|
||||
for i := o.length - 1; i >= 0; i-- {
|
||||
buf[i] = byte(size & 0xFF)
|
||||
size >>= 8
|
||||
@@ -49,11 +50,11 @@ func (o *dataStringObf) Write(writer io.Writer, ctx *writeContext) error {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
|
||||
tmp := ctx.tmpPool.GetBuffer()
|
||||
defer ctx.tmpPool.Put(tmp)
|
||||
|
||||
base64len := base64.RawStdEncoding.EncodedLen(len(data))
|
||||
buf := ctx.TempBuffer(base64len)
|
||||
if buf == nil {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
buf := tmp[:base64len]
|
||||
|
||||
base64.RawStdEncoding.Encode(buf, data)
|
||||
|
||||
@@ -62,11 +63,10 @@ func (o *dataStringObf) Write(writer io.Writer, ctx *writeContext) error {
|
||||
}
|
||||
|
||||
func (o *randObf) Write(writer io.Writer, ctx *writeContext) error {
|
||||
buf := ctx.TempBuffer(o.length)
|
||||
if buf == nil {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
tmp := ctx.tmpPool.GetBuffer()
|
||||
defer ctx.tmpPool.Put(tmp)
|
||||
|
||||
buf := tmp[:o.length]
|
||||
rand.Read(buf)
|
||||
|
||||
_, err := writer.Write(buf)
|
||||
@@ -76,11 +76,10 @@ func (o *randObf) Write(writer io.Writer, ctx *writeContext) error {
|
||||
const chars52 = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
func (o *randCharObf) Write(writer io.Writer, ctx *writeContext) error {
|
||||
buf := ctx.TempBuffer(o.length)
|
||||
if buf == nil {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
tmp := ctx.tmpPool.GetBuffer()
|
||||
defer ctx.tmpPool.Put(tmp)
|
||||
|
||||
buf := tmp[:o.length]
|
||||
rand.Read(buf)
|
||||
for i := range buf {
|
||||
buf[i] = chars52[buf[i]%52]
|
||||
@@ -93,11 +92,10 @@ func (o *randCharObf) Write(writer io.Writer, ctx *writeContext) error {
|
||||
const digits10 = "0123456789"
|
||||
|
||||
func (o *randDigitObf) Write(writer io.Writer, ctx *writeContext) error {
|
||||
buf := ctx.TempBuffer(o.length)
|
||||
if buf == nil {
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
tmp := ctx.tmpPool.GetBuffer()
|
||||
defer ctx.tmpPool.Put(tmp)
|
||||
|
||||
buf := tmp[:o.length]
|
||||
rand.Read(buf)
|
||||
for i := range buf {
|
||||
buf[i] = digits10[buf[i]%10]
|
||||
|
||||
@@ -149,7 +149,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||
sendBuffer = append(sendBuffer, buf)
|
||||
}
|
||||
|
||||
var buf [12 + MessageInitiationSize]byte
|
||||
var buf [MessageInitiationSize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
binary.Write(writer, binary.LittleEndian, msg)
|
||||
packet := writer.Bytes()
|
||||
|
||||
Reference in New Issue
Block a user