fix: refactor processing of junk packets (#103)

- fix the bug that transport packet interprets as init/resp/cookie with the same size
- cleanup error responses
- reduce buffer allocations
This commit is contained in:
Yaroslav Gurov
2025-12-01 13:07:48 +01:00
committed by GitHub
parent f6542209f4
commit 0361c54dca
33 changed files with 852 additions and 2832 deletions

View File

@@ -1,90 +0,0 @@
package awg
import (
"bytes"
"fmt"
"sync"
"github.com/tevino/abool"
)
type Cfg struct {
IsSet bool
JunkPacketCount int
JunkPacketMinSize int
JunkPacketMaxSize int
InitHeaderJunkSize int
ResponseHeaderJunkSize int
CookieReplyHeaderJunkSize int
TransportHeaderJunkSize int
MagicHeaders MagicHeaders
}
type Protocol struct {
IsOn abool.AtomicBool
// TODO: revision the need of the mutex
Mux sync.RWMutex
Cfg Cfg
JunkCreator JunkCreator
HandshakeHandler SpecialHandshakeHandler
}
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.Cfg.InitHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.Cfg.ResponseHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.Cfg.CookieReplyHeaderJunkSize, 0)
}
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
protocol.Mux.RLock()
defer protocol.Mux.RUnlock()
return protocol.createHeaderJunk(protocol.Cfg.TransportHeaderJunkSize, packetSize)
}
func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) {
if junkSize == 0 {
return nil, nil
}
buf := make([]byte, 0, junkSize+extraSize)
writer := bytes.NewBuffer(buf[:0])
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
if err != nil {
return nil, fmt.Errorf("append junk: %w", err)
}
return writer.Bytes(), nil
}
func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
for _, magicHeader := range protocol.Cfg.MagicHeaders.Values {
if magicHeader.Min <= msgType && msgType <= magicHeader.Max {
return magicHeader.Min, nil
}
}
return 0, fmt.Errorf("no header for value: %d", msgType)
}
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
return protocol.Cfg.MagicHeaders.Get(defaultMsgType)
}

View File

@@ -1,37 +0,0 @@
package internal
type mockGenerator struct {
size int
}
func NewMockGenerator(size int) mockGenerator {
return mockGenerator{size: size}
}
func (m mockGenerator) Generate() []byte {
return make([]byte, m.size)
}
func (m mockGenerator) Size() int {
return m.size
}
func (m mockGenerator) Name() string {
return "mock"
}
type mockByteGenerator struct {
data []byte
}
func NewMockByteGenerator(data []byte) mockByteGenerator {
return mockByteGenerator{data: data}
}
func (bg mockByteGenerator) Generate() []byte {
return bg.data
}
func (bg mockByteGenerator) Size() int {
return len(bg.data)
}

View File

@@ -1,50 +0,0 @@
package awg
import (
"bytes"
"fmt"
)
type JunkCreator struct {
cfg Cfg
randomGenerator PRNG[int]
}
// TODO: refactor param to only pass the junk related params
func NewJunkCreator(cfg Cfg) JunkCreator {
return JunkCreator{cfg: cfg, randomGenerator: NewPRNG[int]()}
}
// Should be called with awg mux RLocked
func (jc *JunkCreator) CreateJunkPackets(junks *[][]byte) {
if jc.cfg.JunkPacketCount == 0 {
return
}
for range jc.cfg.JunkPacketCount {
packetSize := jc.randomPacketSize()
junk := jc.randomJunkWithSize(packetSize)
*junks = append(*junks, junk)
}
return
}
// Should be called with awg mux RLocked
func (jc *JunkCreator) randomPacketSize() int {
return jc.randomGenerator.RandomSizeInRange(jc.cfg.JunkPacketMinSize, jc.cfg.JunkPacketMaxSize)
}
// Should be called with awg mux RLocked
func (jc *JunkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
headerJunk := jc.randomJunkWithSize(size)
_, err := writer.Write(headerJunk)
if err != nil {
return fmt.Errorf("write header junk: %v", err)
}
return nil
}
// Should be called with awg mux RLocked
func (jc *JunkCreator) randomJunkWithSize(size int) []byte {
return jc.randomGenerator.ReadSize(size)
}

View File

@@ -1,97 +0,0 @@
package awg
import (
"bytes"
"fmt"
"testing"
)
func setUpJunkCreator() JunkCreator {
mh, _ := NewMagicHeaders(
[]MagicHeader{
NewMagicHeaderSameValue(123456),
NewMagicHeaderSameValue(67543),
NewMagicHeaderSameValue(32345),
NewMagicHeaderSameValue(123123),
},
)
jc := NewJunkCreator(Cfg{
IsSet: true,
JunkPacketCount: 5,
JunkPacketMinSize: 500,
JunkPacketMaxSize: 1000,
InitHeaderJunkSize: 30,
ResponseHeaderJunkSize: 40,
MagicHeaders: mh,
})
return jc
}
func Test_junkCreator_createJunkPackets(t *testing.T) {
jc := setUpJunkCreator()
t.Run("valid", func(t *testing.T) {
got := make([][]byte, 0, jc.cfg.JunkPacketCount)
jc.CreateJunkPackets(&got)
seen := make(map[string]bool)
for _, junk := range got {
key := string(junk)
if seen[key] {
t.Errorf(
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
got,
junk,
)
return
}
seen[key] = true
}
})
}
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
t.Run("valid", func(t *testing.T) {
jc := setUpJunkCreator()
r1 := jc.randomJunkWithSize(10)
r2 := jc.randomJunkWithSize(10)
fmt.Printf("%v\n%v\n", r1, r2)
if bytes.Equal(r1, r2) {
t.Errorf("same junks")
return
}
})
}
func Test_junkCreator_randomPacketSize(t *testing.T) {
jc := setUpJunkCreator()
for range [30]struct{}{} {
t.Run("valid", func(t *testing.T) {
if got := jc.randomPacketSize(); jc.cfg.JunkPacketMinSize > got ||
got > jc.cfg.JunkPacketMaxSize {
t.Errorf(
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
got,
jc.cfg.JunkPacketMinSize,
jc.cfg.JunkPacketMaxSize,
)
}
})
}
}
func Test_junkCreator_appendJunk(t *testing.T) {
jc := setUpJunkCreator()
t.Run("valid", func(t *testing.T) {
s := "apple"
buffer := bytes.NewBuffer([]byte(s))
err := jc.AppendJunk(buffer, 30)
if err != nil &&
buffer.Len() != len(s)+30 {
t.Error("appendWithJunk() size don't match")
}
read := make([]byte, 50)
buffer.Read(read)
fmt.Println(string(read))
})
}

View File

@@ -1,97 +0,0 @@
package awg
import (
"cmp"
"fmt"
"slices"
"strconv"
"strings"
)
type MagicHeader struct {
Min uint32
Max uint32
}
func NewMagicHeaderSameValue(value uint32) MagicHeader {
return MagicHeader{Min: value, Max: value}
}
func NewMagicHeader(min, max uint32) (MagicHeader, error) {
if min > max {
return MagicHeader{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
}
return MagicHeader{Min: min, Max: max}, nil
}
func ParseMagicHeader(key, value string) (MagicHeader, error) {
hyphenIdx := strings.Index(value, "-")
if hyphenIdx == -1 {
// if there is no hyphen, we treat it as single magic header value
magicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
}
return NewMagicHeader(uint32(magicHeader), uint32(magicHeader))
}
minStr := value[:hyphenIdx]
maxStr := value[hyphenIdx+1:]
if len(minStr) == 0 || len(maxStr) == 0 {
return MagicHeader{}, fmt.Errorf("invalid value for key: %s; value: %s; expected format: min-max", key, value)
}
min, err := strconv.ParseUint(minStr, 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, minStr, err)
}
max, err := strconv.ParseUint(maxStr, 10, 32)
if err != nil {
return MagicHeader{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, maxStr, err)
}
magicHeader, err := NewMagicHeader(uint32(min), uint32(max))
if err != nil {
return MagicHeader{}, fmt.Errorf("new magicHeader key: %s; value: %s-%s; %w", key, minStr, maxStr, err)
}
return magicHeader, nil
}
type MagicHeaders struct {
Values []MagicHeader
randomGenerator RandomNumberGenerator[uint32]
}
func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) {
if len(headerValues) != 4 {
return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", headerValues)
}
sortedMagicHeaders := slices.SortedFunc(slices.Values(headerValues), func(lhs MagicHeader, rhs MagicHeader) int {
return cmp.Compare(lhs.Min, rhs.Min)
})
for i := range 3 {
if sortedMagicHeaders[i].Max >= sortedMagicHeaders[i+1].Min {
return MagicHeaders{}, fmt.Errorf(
"magic headers shouldn't overlap; %v > %v",
sortedMagicHeaders[i].Max,
sortedMagicHeaders[i+1].Min,
)
}
}
return MagicHeaders{Values: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
}
func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
if defaultMsgType == 0 || defaultMsgType > 4 {
return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
}
return mh.randomGenerator.RandomSizeInRange(mh.Values[defaultMsgType-1].Min, mh.Values[defaultMsgType-1].Max), nil
}

View File

@@ -1,488 +0,0 @@
package awg
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNewMagicHeaderSameValue(t *testing.T) {
tests := []struct {
name string
value uint32
expected MagicHeader
}{
{
name: "zero value",
value: 0,
expected: MagicHeader{Min: 0, Max: 0},
},
{
name: "small value",
value: 1,
expected: MagicHeader{Min: 1, Max: 1},
},
{
name: "large value",
value: 4294967295, // max uint32
expected: MagicHeader{Min: 4294967295, Max: 4294967295},
},
{
name: "medium value",
value: 1000,
expected: MagicHeader{Min: 1000, Max: 1000},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result := NewMagicHeaderSameValue(tt.value)
require.Equal(t, tt.expected, result)
})
}
}
func TestNewMagicHeader(t *testing.T) {
tests := []struct {
name string
min uint32
max uint32
expected MagicHeader
errorMsg string
}{
{
name: "valid range",
min: 1,
max: 10,
expected: MagicHeader{Min: 1, Max: 10},
},
{
name: "equal values",
min: 5,
max: 5,
expected: MagicHeader{Min: 5, Max: 5},
},
{
name: "zero range",
min: 0,
max: 0,
expected: MagicHeader{Min: 0, Max: 0},
},
{
name: "max uint32 range",
min: 4294967294,
max: 4294967295,
expected: MagicHeader{Min: 4294967294, Max: 4294967295},
},
{
name: "min greater than max",
min: 10,
max: 5,
expected: MagicHeader{},
errorMsg: "min (10) cannot be greater than max (5)",
},
{
name: "large min greater than max",
min: 4294967295,
max: 1,
expected: MagicHeader{},
errorMsg: "min (4294967295) cannot be greater than max (1)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result, err := NewMagicHeader(tt.min, tt.max)
if tt.errorMsg != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMsg)
require.Equal(t, MagicHeader{}, result)
} else {
require.NoError(t, err)
require.Equal(t, tt.expected, result)
}
})
}
}
func TestParseMagicHeader(t *testing.T) {
tests := []struct {
name string
key string
value string
expected MagicHeader
errorMsg string
}{
{
name: "single value",
key: "header1",
value: "100",
expected: MagicHeader{Min: 100, Max: 100},
},
{
name: "valid range",
key: "header2",
value: "10-20",
expected: MagicHeader{Min: 10, Max: 20},
},
{
name: "zero single value",
key: "header3",
value: "0",
expected: MagicHeader{Min: 0, Max: 0},
},
{
name: "zero range",
key: "header4",
value: "0-0",
expected: MagicHeader{Min: 0, Max: 0},
},
{
name: "max uint32 single",
key: "header5",
value: "4294967295",
expected: MagicHeader{Min: 4294967295, Max: 4294967295},
},
{
name: "max uint32 range",
key: "header6",
value: "4294967294-4294967295",
expected: MagicHeader{Min: 4294967294, Max: 4294967295},
},
{
name: "invalid single value - not number",
key: "header7",
value: "abc",
expected: MagicHeader{},
errorMsg: "parse key: header7; value: abc;",
},
{
name: "invalid single value - negative",
key: "header8",
value: "-5",
expected: MagicHeader{},
errorMsg: "invalid value for key: header8; value: -5;",
},
{
name: "invalid single value - too large",
key: "header9",
value: "4294967296",
expected: MagicHeader{},
errorMsg: "parse key: header9; value: 4294967296;",
},
{
name: "invalid range - min not number",
key: "header10",
value: "abc-10",
expected: MagicHeader{},
errorMsg: "parse min key: header10; value: abc;",
},
{
name: "invalid range - max not number",
key: "header11",
value: "10-abc",
expected: MagicHeader{},
errorMsg: "parse max key: header11; value: abc;",
},
{
name: "invalid range - min greater than max",
key: "header12",
value: "20-10",
expected: MagicHeader{},
errorMsg: "new magicHeader key: header12; value: 20-10;",
},
{
name: "invalid range - too many parts",
key: "header13",
value: "10-20-30",
expected: MagicHeader{},
errorMsg: "parse key: header13; value: 10-20-30;",
},
{
name: "empty value",
key: "header14",
value: "",
expected: MagicHeader{},
errorMsg: "parse key: header14; value: ;",
},
{
name: "hyphen only",
key: "header15",
value: "-",
expected: MagicHeader{},
errorMsg: "invalid value for key: header15; value: -;",
},
{
name: "empty min",
key: "header16",
value: "-10",
expected: MagicHeader{},
errorMsg: "invalid value for key: header16; value: -10;",
},
{
name: "empty max",
key: "header17",
value: "10-",
expected: MagicHeader{},
errorMsg: "invalid value for key: header17; value: 10-;",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result, err := ParseMagicHeader(tt.key, tt.value)
if tt.errorMsg != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMsg)
require.Equal(t, MagicHeader{}, result)
} else {
require.NoError(t, err)
require.Equal(t, tt.expected, result)
}
})
}
}
func TestNewMagicHeaders(t *testing.T) {
tests := []struct {
name string
magicHeaders []MagicHeader
errorMsg string
}{
{
name: "valid non-overlapping headers",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 11, Max: 20},
{Min: 21, Max: 30},
{Min: 31, Max: 40},
},
},
{
name: "valid adjacent headers",
magicHeaders: []MagicHeader{
{Min: 1, Max: 1},
{Min: 2, Max: 2},
{Min: 3, Max: 3},
{Min: 4, Max: 4},
},
},
{
name: "valid zero-based headers",
magicHeaders: []MagicHeader{
{Min: 0, Max: 0},
{Min: 1, Max: 1},
{Min: 2, Max: 2},
{Min: 3, Max: 3},
},
},
{
name: "valid large value headers",
magicHeaders: []MagicHeader{
{Min: 4294967290, Max: 4294967291},
{Min: 4294967292, Max: 4294967293},
{Min: 4294967294, Max: 4294967294},
{Min: 4294967295, Max: 4294967295},
},
},
{
name: "too few headers",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 11, Max: 20},
{Min: 21, Max: 30},
},
errorMsg: "all header types should be included:",
},
{
name: "too many headers",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 11, Max: 20},
{Min: 21, Max: 30},
{Min: 31, Max: 40},
{Min: 41, Max: 50},
},
errorMsg: "all header types should be included:",
},
{
name: "empty headers",
magicHeaders: []MagicHeader{},
errorMsg: "all header types should be included:",
},
{
name: "overlapping headers",
magicHeaders: []MagicHeader{
{Min: 1, Max: 15},
{Min: 10, Max: 20},
{Min: 25, Max: 30},
{Min: 35, Max: 40},
},
errorMsg: "magic headers shouldn't overlap;",
},
{
name: "overlapping headers at limit-first",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 10, Max: 20},
{Min: 25, Max: 30},
{Min: 35, Max: 40},
},
errorMsg: "magic headers shouldn't overlap;",
},
{
name: "overlapping headers at limit-second",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 15, Max: 25},
{Min: 25, Max: 30},
{Min: 35, Max: 40},
},
errorMsg: "magic headers shouldn't overlap;",
},
{
name: "overlapping headers at limit-third",
magicHeaders: []MagicHeader{
{Min: 1, Max: 10},
{Min: 15, Max: 25},
{Min: 30, Max: 35},
{Min: 35, Max: 40},
},
errorMsg: "magic headers shouldn't overlap;",
},
{
name: "identical ranges",
magicHeaders: []MagicHeader{
{Min: 10, Max: 20},
{Min: 10, Max: 20},
{Min: 25, Max: 30},
{Min: 35, Max: 40},
},
errorMsg: "magic headers shouldn't overlap;",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
result, err := NewMagicHeaders(tt.magicHeaders)
if tt.errorMsg != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMsg)
require.Equal(t, MagicHeaders{}, result)
} else {
require.NoError(t, err)
require.Equal(t, tt.magicHeaders, result.Values)
require.NotNil(t, result.randomGenerator)
}
})
}
}
// Mock PRNG for testing
type mockPRNG struct {
returnValue uint32
}
func (m *mockPRNG) RandomSizeInRange(min, max uint32) uint32 {
return m.returnValue
}
func (m *mockPRNG) Get() uint64 {
return 0
}
func (m *mockPRNG) ReadSize(size int) []byte {
return make([]byte, size)
}
func TestMagicHeaders_Get(t *testing.T) {
// Create test headers
headers := []MagicHeader{
{Min: 1, Max: 10},
{Min: 11, Max: 20},
{Min: 21, Max: 30},
{Min: 31, Max: 40},
}
tests := []struct {
name string
defaultMsgType uint32
mockValue uint32
expectedValue uint32
errorMsg string
}{
{
name: "valid type 1",
defaultMsgType: 1,
mockValue: 5,
expectedValue: 5,
},
{
name: "valid type 2",
defaultMsgType: 2,
mockValue: 15,
expectedValue: 15,
},
{
name: "valid type 3",
defaultMsgType: 3,
mockValue: 25,
expectedValue: 25,
},
{
name: "valid type 4",
defaultMsgType: 4,
mockValue: 35,
expectedValue: 35,
},
{
name: "invalid type 0",
defaultMsgType: 0,
mockValue: 0,
expectedValue: 0,
errorMsg: "invalid msg type: 0",
},
{
name: "invalid type 5",
defaultMsgType: 5,
mockValue: 0,
expectedValue: 0,
errorMsg: "invalid msg type: 5",
},
{
name: "invalid type max uint32",
defaultMsgType: 4294967295,
mockValue: 0,
expectedValue: 0,
errorMsg: "invalid msg type: 4294967295",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// Create a new instance with mock PRNG for each test
testMagicHeaders := MagicHeaders{
Values: headers,
randomGenerator: &mockPRNG{returnValue: tt.mockValue},
}
result, err := testMagicHeaders.Get(tt.defaultMsgType)
if tt.errorMsg != "" {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMsg)
require.Equal(t, uint32(0), result)
} else {
require.NoError(t, err)
require.Equal(t, tt.expectedValue, result)
}
})
}
}

View File

@@ -1,50 +0,0 @@
package awg
import (
crand "crypto/rand"
v2 "math/rand/v2"
"golang.org/x/exp/constraints"
)
type RandomNumberGenerator[T constraints.Integer] interface {
RandomSizeInRange(min, max T) T
Get() uint64
ReadSize(size int) []byte
}
type PRNG[T constraints.Integer] struct {
cha8Rand *v2.ChaCha8
}
func NewPRNG[T constraints.Integer]() PRNG[T] {
buf := make([]byte, 32)
_, _ = crand.Read(buf)
return PRNG[T]{
cha8Rand: v2.NewChaCha8([32]byte(buf)),
}
}
func (p PRNG[T]) RandomSizeInRange(min, max T) T {
if min > max {
panic("min must be less than max")
}
if min == max {
return min
}
return T(p.Get()%uint64(max-min)) + min
}
func (p PRNG[T]) Get() uint64 {
return p.cha8Rand.Uint64()
}
func (p PRNG[T]) ReadSize(size int) []byte {
// TODO: use a memory pool to allocate
buf := make([]byte, size)
_, _ = p.cha8Rand.Read(buf)
return buf
}

View File

@@ -1,36 +0,0 @@
package awg
import (
"github.com/tevino/abool"
"go.uber.org/atomic"
)
// TODO: atomic?/ and better way to use this
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
// TODO
var WaitResponse = struct {
Channel chan struct{}
ShouldWait *abool.AtomicBool
}{
make(chan struct{}, 1),
abool.New(),
}
type SpecialHandshakeHandler struct {
SpecialJunk TagJunkPacketGenerators
IsSet bool
}
func (handler *SpecialHandshakeHandler) Validate() error {
return handler.SpecialJunk.Validate()
}
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
if !handler.SpecialJunk.IsDefined() {
return nil
}
return handler.SpecialJunk.GeneratePackets()
}

View File

@@ -1,229 +0,0 @@
package awg
import (
crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"fmt"
"strconv"
"strings"
"time"
v2 "math/rand/v2"
// "go.uber.org/atomic"
)
type Generator interface {
Generate() []byte
Size() int
}
type newGenerator func(string) (Generator, error)
type BytesGenerator struct {
value []byte
size int
}
func (bg *BytesGenerator) Generate() []byte {
return bg.value
}
func (bg *BytesGenerator) Size() int {
return bg.size
}
func newBytesGenerator(param string) (Generator, error) {
hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X")
if !hasPrefix {
return nil, fmt.Errorf("not correct hex: %s", param)
}
hex, err := hexToBytes(param)
if err != nil {
return nil, fmt.Errorf("hexToBytes: %w", err)
}
return &BytesGenerator{value: hex, size: len(hex)}, nil
}
func hexToBytes(hexStr string) ([]byte, error) {
hexStr = strings.TrimPrefix(hexStr, "0x")
hexStr = strings.TrimPrefix(hexStr, "0X")
// Ensure even length (pad with leading zero if needed)
if len(hexStr)%2 != 0 {
hexStr = "0" + hexStr
}
return hex.DecodeString(hexStr)
}
type randomGeneratorBase struct {
cha8Rand *v2.ChaCha8
size int
}
func newRandomGeneratorBase(param string) (*randomGeneratorBase, error) {
size, err := strconv.Atoi(param)
if err != nil {
return nil, fmt.Errorf("parse int: %w", err)
}
if size > 1000 {
return nil, fmt.Errorf("size must be less than 1000")
}
buf := make([]byte, 32)
_, err = crand.Read(buf)
if err != nil {
return nil, fmt.Errorf("crand read: %w", err)
}
return &randomGeneratorBase{
cha8Rand: v2.NewChaCha8([32]byte(buf)),
size: size,
}, nil
}
func (rpg *randomGeneratorBase) generate() []byte {
junk := make([]byte, rpg.size)
rpg.cha8Rand.Read(junk)
return junk
}
func (rpg *randomGeneratorBase) Size() int {
return rpg.size
}
type RandomBytesGenerator struct {
*randomGeneratorBase
}
func newRandomBytesGenerator(param string) (Generator, error) {
rpgBase, err := newRandomGeneratorBase(param)
if err != nil {
return nil, fmt.Errorf("new random bytes generator: %w", err)
}
return &RandomBytesGenerator{randomGeneratorBase: rpgBase}, nil
}
func (rpg *RandomBytesGenerator) Generate() []byte {
return rpg.generate()
}
const alphanumericChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
type RandomASCIIGenerator struct {
*randomGeneratorBase
}
func newRandomASCIIGenerator(param string) (Generator, error) {
rpgBase, err := newRandomGeneratorBase(param)
if err != nil {
return nil, fmt.Errorf("new random ascii generator: %w", err)
}
return &RandomASCIIGenerator{randomGeneratorBase: rpgBase}, nil
}
func (rpg *RandomASCIIGenerator) Generate() []byte {
junk := rpg.generate()
result := make([]byte, rpg.size)
for i, b := range junk {
result[i] = alphanumericChars[b%byte(len(alphanumericChars))]
}
return result
}
type RandomDigitGenerator struct {
*randomGeneratorBase
}
func newRandomDigitGenerator(param string) (Generator, error) {
rpgBase, err := newRandomGeneratorBase(param)
if err != nil {
return nil, fmt.Errorf("new random digit generator: %w", err)
}
return &RandomDigitGenerator{randomGeneratorBase: rpgBase}, nil
}
func (rpg *RandomDigitGenerator) Generate() []byte {
junk := rpg.generate()
result := make([]byte, rpg.size)
for i, b := range junk {
result[i] = '0' + (b % 10) // Convert to digit character
}
return result
}
type TimestampGenerator struct {
}
func (tg *TimestampGenerator) Generate() []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
return buf
}
func (tg *TimestampGenerator) Size() int {
return 8
}
func newTimestampGenerator(param string) (Generator, error) {
if len(param) != 0 {
return nil, fmt.Errorf("timestamp param needs to be empty: %s", param)
}
return &TimestampGenerator{}, nil
}
type PacketCounterGenerator struct {
}
func (c *PacketCounterGenerator) Generate() []byte {
buf := make([]byte, 8)
// TODO: better way to handle counter tag
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
return buf
}
func (c *PacketCounterGenerator) Size() int {
return 8
}
func newPacketCounterGenerator(param string) (Generator, error) {
if len(param) != 0 {
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
}
return &PacketCounterGenerator{}, nil
}
type WaitResponseGenerator struct {
}
func (c *WaitResponseGenerator) Generate() []byte {
WaitResponse.ShouldWait.Set()
<-WaitResponse.Channel
WaitResponse.ShouldWait.UnSet()
return []byte{}
}
func (c *WaitResponseGenerator) Size() int {
return 0
}
func newWaitResponseGenerator(param string) (Generator, error) {
if len(param) != 0 {
return nil, fmt.Errorf("wait response param needs to be empty: %s", param)
}
return &WaitResponseGenerator{}, nil
}

View File

@@ -1,321 +0,0 @@
package awg
import (
"encoding/binary"
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestNewBytesGenerator(t *testing.T) {
t.Parallel()
type args struct {
param string
}
tests := []struct {
name string
args args
want []byte
wantErr error
}{
{
name: "empty",
args: args{
param: "",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "wrong start",
args: args{
param: "123456",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "not only hex value with X",
args: args{
param: "0X12345q",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "not only hex value with x",
args: args{
param: "0x12345q",
},
wantErr: fmt.Errorf("not correct hex"),
},
{
name: "valid hex",
args: args{
param: "0xf6ab3267fa",
},
want: []byte{0xf6, 0xab, 0x32, 0x67, 0xfa},
},
{
name: "valid hex with odd length",
args: args{
param: "0xfab3267fa",
},
want: []byte{0xf, 0xab, 0x32, 0x67, 0xfa},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := newBytesGenerator(tt.args.param)
if tt.wantErr != nil {
require.ErrorAs(t, err, &tt.wantErr)
require.Nil(t, got)
return
}
require.Nil(t, err)
require.NotNil(t, got)
gotValues := got.Generate()
require.Equal(t, tt.want, gotValues)
})
}
}
func TestNewRandomBytesGenerator(t *testing.T) {
t.Parallel()
type args struct {
param string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "empty",
args: args{
param: "",
},
wantErr: fmt.Errorf("parse int"),
},
{
name: "not an int",
args: args{
param: "x",
},
wantErr: fmt.Errorf("parse int"),
},
{
name: "too large",
args: args{
param: "1001",
},
wantErr: fmt.Errorf("random packet size must be less than 1000"),
},
{
name: "valid",
args: args{
param: "12",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := newRandomBytesGenerator(tt.args.param)
if tt.wantErr != nil {
require.ErrorAs(t, err, &tt.wantErr)
require.Nil(t, got)
return
}
require.Nil(t, err)
require.NotNil(t, got)
first := got.Generate()
second := got.Generate()
require.NotEqual(t, first, second)
})
}
}
func TestNewRandomASCIIGenerator(t *testing.T) {
t.Parallel()
type args struct {
param string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "empty",
args: args{
param: "",
},
wantErr: fmt.Errorf("parse int"),
},
{
name: "not an int",
args: args{
param: "x",
},
wantErr: fmt.Errorf("parse int"),
},
{
name: "too large",
args: args{
param: "1001",
},
wantErr: fmt.Errorf("random packet size must be less than 1000"),
},
{
name: "valid",
args: args{
param: "12",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := newRandomASCIIGenerator(tt.args.param)
if tt.wantErr != nil {
require.ErrorAs(t, err, &tt.wantErr)
require.Nil(t, got)
return
}
require.Nil(t, err)
require.NotNil(t, got)
first := got.Generate()
second := got.Generate()
require.NotEqual(t, first, second)
})
}
}
func TestNewRandomDigitGenerator(t *testing.T) {
t.Parallel()
type args struct {
param string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "empty",
args: args{
param: "",
},
wantErr: fmt.Errorf("parse int"),
},
{
name: "not an int",
args: args{
param: "x",
},
wantErr: fmt.Errorf("parse int"),
},
{
name: "too large",
args: args{
param: "1001",
},
wantErr: fmt.Errorf("random packet size must be less than 1000"),
},
{
name: "valid",
args: args{
param: "12",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := newRandomDigitGenerator(tt.args.param)
if tt.wantErr != nil {
require.ErrorAs(t, err, &tt.wantErr)
require.Nil(t, got)
return
}
require.Nil(t, err)
require.NotNil(t, got)
first := got.Generate()
second := got.Generate()
require.NotEqual(t, first, second)
})
}
}
func TestPacketCounterGenerator(t *testing.T) {
t.Parallel()
tests := []struct {
name string
param string
wantErr bool
}{
{
name: "Valid empty param",
param: "",
wantErr: false,
},
{
name: "Invalid non-empty param",
param: "anything",
wantErr: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
gen, err := newPacketCounterGenerator(tc.param)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, 8, gen.Size())
// Reset counter to known value for test
initialCount := uint64(42)
PacketCounter.Store(initialCount)
output := gen.Generate()
require.Equal(t, 8, len(output))
// Verify counter value in output
counterValue := binary.BigEndian.Uint64(output)
require.Equal(t, initialCount, counterValue)
// Increment counter and verify change
PacketCounter.Add(1)
output = gen.Generate()
counterValue = binary.BigEndian.Uint64(output)
require.Equal(t, initialCount+1, counterValue)
})
}
}

View File

@@ -1,59 +0,0 @@
package awg
import (
"fmt"
"strconv"
)
type TagJunkPacketGenerator struct {
name string
tagValue string
packetSize int
generators []Generator
}
func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
return TagJunkPacketGenerator{
name: name,
tagValue: tagValue,
generators: make([]Generator, 0, size),
}
}
func (tg *TagJunkPacketGenerator) append(generator Generator) {
tg.generators = append(tg.generators, generator)
tg.packetSize += generator.Size()
}
func (tg *TagJunkPacketGenerator) generatePacket() []byte {
packet := make([]byte, 0, tg.packetSize)
for _, generator := range tg.generators {
packet = append(packet, generator.Generate()...)
}
return packet
}
func (tg *TagJunkPacketGenerator) Name() string {
return tg.name
}
func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
if len(tg.name) != 2 {
return 0, fmt.Errorf("name must be 2 character long: %s", tg.name)
}
index, err := strconv.Atoi(tg.name[1:2])
if err != nil {
return 0, fmt.Errorf("name 2 char should be an int %w", err)
}
return index, nil
}
func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
return IpcFields{
Key: tg.name,
Value: tg.tagValue,
}
}

View File

@@ -1,210 +0,0 @@
package awg
import (
"testing"
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
"github.com/stretchr/testify/require"
)
func TestNewTagJunkGenerator(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
genName string
size int
expected TagJunkPacketGenerator
}{
{
name: "Create new generator with empty name",
genName: "",
size: 0,
expected: TagJunkPacketGenerator{
name: "",
packetSize: 0,
generators: make([]Generator, 0),
},
},
{
name: "Create new generator with valid name",
genName: "T1",
size: 0,
expected: TagJunkPacketGenerator{
name: "T1",
packetSize: 0,
generators: make([]Generator, 0),
},
},
{
name: "Create new generator with non-zero size",
genName: "T2",
size: 5,
expected: TagJunkPacketGenerator{
name: "T2",
packetSize: 0,
generators: make([]Generator, 5),
},
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
result := newTagJunkPacketGenerator(tc.genName, "", tc.size)
require.Equal(t, tc.expected.name, result.name)
require.Equal(t, tc.expected.packetSize, result.packetSize)
require.Equal(t, cap(result.generators), len(tc.expected.generators))
})
}
}
func TestTagJunkGeneratorAppend(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
initialState TagJunkPacketGenerator
mockSize int
expectedLength int
expectedSize int
}{
{
name: "Append to empty generator",
initialState: newTagJunkPacketGenerator("T1", "", 0),
mockSize: 5,
expectedLength: 1,
expectedSize: 5,
},
{
name: "Append to non-empty generator",
initialState: TagJunkPacketGenerator{
name: "T2",
packetSize: 10,
generators: make([]Generator, 2),
},
mockSize: 7,
expectedLength: 3, // 2 existing + 1 new
expectedSize: 17, // 10 + 7
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tg := tc.initialState
mockGen := internal.NewMockGenerator(tc.mockSize)
tg.append(mockGen)
require.Equal(t, tc.expectedLength, len(tg.generators))
require.Equal(t, tc.expectedSize, tg.packetSize)
})
}
}
func TestTagJunkGeneratorGenerate(t *testing.T) {
t.Parallel()
// Create mock generators for testing
mockGen1 := internal.NewMockByteGenerator([]byte{0x01, 0x02})
mockGen2 := internal.NewMockByteGenerator([]byte{0x03, 0x04, 0x05})
testCases := []struct {
name string
setupGenerator func() TagJunkPacketGenerator
expected []byte
}{
{
name: "Generate with empty generators",
setupGenerator: func() TagJunkPacketGenerator {
return newTagJunkPacketGenerator("T1", "", 0)
},
expected: []byte{},
},
{
name: "Generate with single generator",
setupGenerator: func() TagJunkPacketGenerator {
tg := newTagJunkPacketGenerator("T2", "", 0)
tg.append(mockGen1)
return tg
},
expected: []byte{0x01, 0x02},
},
{
name: "Generate with multiple generators",
setupGenerator: func() TagJunkPacketGenerator {
tg := newTagJunkPacketGenerator("T3", "", 0)
tg.append(mockGen1)
tg.append(mockGen2)
return tg
},
expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tg := tc.setupGenerator()
result := tg.generatePacket()
require.Equal(t, tc.expected, result)
})
}
}
func TestTagJunkGeneratorNameIndex(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
generatorName string
expectedIndex int
expectError bool
}{
{
name: "Valid name with digit",
generatorName: "T5",
expectedIndex: 5,
expectError: false,
},
{
name: "Invalid name - too short",
generatorName: "T",
expectError: true,
},
{
name: "Invalid name - too long",
generatorName: "T55",
expectError: true,
},
{
name: "Invalid name - non-digit second character",
generatorName: "TX",
expectError: true,
},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
tg := TagJunkPacketGenerator{name: tc.generatorName}
index, err := tg.nameIndex()
if tc.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tc.expectedIndex, index)
}
})
}
}

View File

@@ -1,66 +0,0 @@
package awg
import "fmt"
type TagJunkPacketGenerators struct {
tagGenerators []TagJunkPacketGenerator
length int
DefaultJunkCount int // Jc
}
func (generators *TagJunkPacketGenerators) AppendGenerator(
generator TagJunkPacketGenerator,
) {
generators.tagGenerators = append(generators.tagGenerators, generator)
generators.length++
}
func (generators *TagJunkPacketGenerators) IsDefined() bool {
return len(generators.tagGenerators) > 0
}
// validate that packets were defined consecutively
func (generators *TagJunkPacketGenerators) Validate() error {
seen := make([]bool, len(generators.tagGenerators))
for _, generator := range generators.tagGenerators {
index, err := generator.nameIndex()
if index > len(generators.tagGenerators) {
return fmt.Errorf("junk packet index should be consecutive")
}
if err != nil {
return fmt.Errorf("name index: %w", err)
} else {
seen[index-1] = true
}
}
for _, found := range seen {
if !found {
return fmt.Errorf("junk packet index should be consecutive")
}
}
return nil
}
func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
for i, tagGenerator := range generators.tagGenerators {
rv = append(rv, make([]byte, tagGenerator.packetSize))
copy(rv[i], tagGenerator.generatePacket())
PacketCounter.Inc()
}
PacketCounter.Add(uint64(generators.DefaultJunkCount))
return rv
}
func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
rv := make([]IpcFields, 0, len(tg.tagGenerators))
for _, generator := range tg.tagGenerators {
rv = append(rv, generator.IpcGetFields())
}
return rv
}

View File

@@ -1,149 +0,0 @@
package awg
import (
"testing"
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
"github.com/stretchr/testify/require"
)
func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
tests := []struct {
name string
generator TagJunkPacketGenerator
}{
{
name: "append single generator",
generator: newTagJunkPacketGenerator("t1", "", 10),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
generators := &TagJunkPacketGenerators{}
// Initial length should be 0
require.Equal(t, 0, generators.length)
require.Empty(t, generators.tagGenerators)
// After append, length should be 1 and generator should be added
generators.AppendGenerator(tt.generator)
require.Equal(t, 1, generators.length)
require.Len(t, generators.tagGenerators, 1)
require.Equal(t, tt.generator, generators.tagGenerators[0])
})
}
}
func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
tests := []struct {
name string
generators []TagJunkPacketGenerator
wantErr bool
errMsg string
}{
{
name: "bad start",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("t3", "", 10),
newTagJunkPacketGenerator("t4", "", 10),
},
wantErr: true,
errMsg: "junk packet index should be consecutive",
},
{
name: "non-consecutive indices",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("t1", "", 10),
newTagJunkPacketGenerator("t3", "", 10), // Missing t2
},
wantErr: true,
errMsg: "junk packet index should be consecutive",
},
{
name: "consecutive indices",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("t1", "", 10),
newTagJunkPacketGenerator("t2", "", 10),
newTagJunkPacketGenerator("t3", "", 10),
newTagJunkPacketGenerator("t4", "", 10),
newTagJunkPacketGenerator("t5", "", 10),
},
},
{
name: "nameIndex error",
generators: []TagJunkPacketGenerator{
newTagJunkPacketGenerator("error", "", 10),
},
wantErr: true,
errMsg: "name must be 2 character long",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
generators := &TagJunkPacketGenerators{}
for _, gen := range tt.generators {
generators.AppendGenerator(gen)
}
err := generators.Validate()
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
return
}
require.NoError(t, err)
})
}
}
func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
mockByte1 := []byte{0x01, 0x02}
mockByte2 := []byte{0x03, 0x04, 0x05}
mockGen1 := internal.NewMockByteGenerator(mockByte1)
mockGen2 := internal.NewMockByteGenerator(mockByte2)
tests := []struct {
name string
setupGenerator func() []TagJunkPacketGenerator
expected [][]byte
}{
{
name: "generate with no default junk",
setupGenerator: func() []TagJunkPacketGenerator {
tg1 := newTagJunkPacketGenerator("t1", "", 0)
tg1.append(mockGen1)
tg1.append(mockGen2)
tg2 := newTagJunkPacketGenerator("t2", "", 0)
tg2.append(mockGen2)
tg2.append(mockGen1)
return []TagJunkPacketGenerator{tg1, tg2}
},
expected: [][]byte{
append(mockByte1, mockByte2...),
append(mockByte2, mockByte1...),
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
generators := &TagJunkPacketGenerators{}
tagGenerators := tt.setupGenerator()
for _, gen := range tagGenerators {
generators.AppendGenerator(gen)
}
result := generators.GeneratePackets()
require.Equal(t, result, tt.expected)
})
}
}

View File

@@ -1,112 +0,0 @@
package awg
import (
"fmt"
"maps"
"regexp"
"strings"
)
type IpcFields struct{ Key, Value string }
type EnumTag string
const (
BytesEnumTag EnumTag = "b"
CounterEnumTag EnumTag = "c"
TimestampEnumTag EnumTag = "t"
RandomBytesEnumTag EnumTag = "r"
RandomASCIIEnumTag EnumTag = "rc"
RandomDigitEnumTag EnumTag = "rd"
)
var generatorCreator = map[EnumTag]newGenerator{
BytesEnumTag: newBytesGenerator,
CounterEnumTag: newPacketCounterGenerator,
TimestampEnumTag: newTimestampGenerator,
RandomBytesEnumTag: newRandomBytesGenerator,
RandomASCIIEnumTag: newRandomASCIIGenerator,
RandomDigitEnumTag: newRandomDigitGenerator,
}
// helper map to determine enumTags are unique
var uniqueTags = map[EnumTag]bool{
CounterEnumTag: false,
TimestampEnumTag: false,
}
type Tag struct {
Name EnumTag
Param string
}
func parseTag(input string) (Tag, error) {
// Regular expression to match <tagname optional_param>
re := regexp.MustCompile(`([a-zA-Z]+)(?:\s+([^>]+))?>`)
match := re.FindStringSubmatch(input)
tag := Tag{
Name: EnumTag(match[1]),
}
if len(match) > 2 && match[2] != "" {
tag.Param = strings.TrimSpace(match[2])
}
return tag, nil
}
func ParseTagJunkGenerator(name, input string) (TagJunkPacketGenerator, error) {
inputSlice := strings.Split(input, "<")
if len(inputSlice) <= 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input)
}
uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags))
maps.Copy(uniqueTagCheck, uniqueTags)
// skip byproduct of split
inputSlice = inputSlice[1:]
rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
for _, inputParam := range inputSlice {
if len(inputParam) <= 1 {
return TagJunkPacketGenerator{}, fmt.Errorf(
"empty tag in input: %s",
inputSlice,
)
} else if strings.Count(inputParam, ">") != 1 {
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
}
tag, _ := parseTag(inputParam)
creator, ok := generatorCreator[tag.Name]
if !ok {
return TagJunkPacketGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name)
}
if present, ok := uniqueTagCheck[tag.Name]; ok {
if present {
return TagJunkPacketGenerator{}, fmt.Errorf(
"tag %s needs to be unique",
tag.Name,
)
}
uniqueTagCheck[tag.Name] = true
}
generator, err := creator(tag.Param)
if err != nil {
return TagJunkPacketGenerator{}, fmt.Errorf("gen: %w", err)
}
// TODO: handle counter tag
// if tag.Name == CounterEnumTag {
// packetCounter, ok := generator.(*PacketCounterGenerator)
// if !ok {
// log.Fatalf("packet counter generator expected, got %T", generator)
// }
// PacketCounter = packetCounter.counter
// }
rv.append(generator)
}
return rv, nil
}

View File

@@ -1,77 +0,0 @@
package awg
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestParse(t *testing.T) {
type args struct {
name string
input string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "invalid name",
args: args{name: "apple", input: ""},
wantErr: fmt.Errorf("ill formated input"),
},
{
name: "empty",
args: args{name: "i1", input: ""},
wantErr: fmt.Errorf("ill formated input"),
},
{
name: "extra >",
args: args{name: "i1", input: "<b 0xf6ab3267fa><c>>"},
wantErr: fmt.Errorf("ill formated input"),
},
{
name: "extra <",
args: args{name: "i1", input: "<<b 0xf6ab3267fa><c>"},
wantErr: fmt.Errorf("empty tag in input"),
},
{
name: "empty <>",
args: args{name: "i1", input: "<><b 0xf6ab3267fa><c>"},
wantErr: fmt.Errorf("empty tag in input"),
},
{
name: "invalid tag",
args: args{name: "i1", input: "<q 0xf6ab3267fa>"},
wantErr: fmt.Errorf("invalid tag"),
},
{
name: "counter uniqueness violation",
args: args{name: "i1", input: "<c><c>"},
wantErr: fmt.Errorf("parse tag needs to be unique"),
},
{
name: "timestamp uniqueness violation",
args: args{name: "i1", input: "<t><t>"},
wantErr: fmt.Errorf("parse tag needs to be unique"),
},
{
name: "valid",
args: args{input: "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := ParseTagJunkGenerator(tt.args.name, tt.args.input)
// TODO: ErrorAs doesn't work as you think
if tt.wantErr != nil {
require.ErrorAs(t, err, &tt.wantErr)
return
}
require.Nil(t, err)
})
}
}

View File

@@ -99,7 +99,7 @@ func TestCookieMAC1(t *testing.T) {
0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d,
}
generator.AddMacs(msg)
reply, err := checker.CreateReply(msg, 1377, src, DefaultMessageCookieReplyType)
reply, err := checker.CreateReply(msg, 1377, src, MessageCookieReplyType)
if err != nil {
t.Fatal("Failed to create cookie reply:", err)
}

View File

@@ -6,57 +6,17 @@
package device
import (
"encoding/binary"
"errors"
"fmt"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device/awg"
"github.com/amnezia-vpn/amneziawg-go/ipc"
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
"github.com/amnezia-vpn/amneziawg-go/tun"
)
type Version uint8
const (
VersionDefault Version = iota
VersionAwg
VersionAwgSpecialHandshake
)
// TODO:
type AtomicVersion struct {
value atomic.Uint32
}
func NewAtomicVersion(v Version) *AtomicVersion {
av := &AtomicVersion{}
av.Store(v)
return av
}
func (av *AtomicVersion) Load() Version {
return Version(av.value.Load())
}
func (av *AtomicVersion) Store(v Version) {
av.value.Store(uint32(v))
}
func (av *AtomicVersion) CompareAndSwap(old, new Version) bool {
return av.value.CompareAndSwap(uint32(old), uint32(new))
}
func (av *AtomicVersion) Swap(new Version) Version {
return Version(av.value.Swap(uint32(new)))
}
type Device struct {
state struct {
// state holds the device's state. It is accessed atomically.
@@ -130,8 +90,27 @@ type Device struct {
closed chan struct{}
log *Logger
version Version
awg awg.Protocol
junk struct {
min int
max int
count int
}
headers struct {
init *magicHeader
cookie *magicHeader
response *magicHeader
transport *magicHeader
}
paddings struct {
init int
response int
cookie int
transport int
}
ipackets [5]*obfChain
}
// deviceState represents the state of a Device.
@@ -342,6 +321,11 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.rate.limiter.Init()
device.indexTable.Init()
device.headers.init = &magicHeader{start: MessageInitiationType, end: MessageInitiationType}
device.headers.response = &magicHeader{start: MessageResponseType, end: MessageResponseType}
device.headers.cookie = &magicHeader{start: MessageCookieReplyType, end: MessageCookieReplyType}
device.headers.transport = &magicHeader{start: MessageTransportType, end: MessageTransportType}
device.PopulatePools()
// create queues
@@ -439,8 +423,6 @@ func (device *Device) Close() {
device.rate.limiter.Close()
device.resetProtocol()
device.log.Verbosef("Device closed")
close(device.closed)
}
@@ -580,358 +562,3 @@ func (device *Device) BindClose() error {
device.net.Unlock()
return err
}
func (device *Device) isAWG() bool {
return device.version >= VersionAwg
}
func (device *Device) resetProtocol() {
// restore default message type values
MessageInitiationType = DefaultMessageInitiationType
MessageResponseType = DefaultMessageResponseType
MessageCookieReplyType = DefaultMessageCookieReplyType
MessageTransportType = DefaultMessageTransportType
}
func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
if !tempAwg.Cfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
return nil
}
var errs []error
isAwgOn := false
device.awg.Mux.Lock()
if tempAwg.Cfg.JunkPacketCount < 0 {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketCount should be non negative",
),
)
}
device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount
if tempAwg.Cfg.JunkPacketCount != 0 {
isAwgOn = true
}
device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize
if tempAwg.Cfg.JunkPacketMinSize != 0 {
isAwgOn = true
}
if device.awg.Cfg.JunkPacketCount > 0 &&
tempAwg.Cfg.JunkPacketMaxSize == tempAwg.Cfg.JunkPacketMinSize {
tempAwg.Cfg.JunkPacketMaxSize++ // to make rand gen work
}
if tempAwg.Cfg.JunkPacketMaxSize >= MaxSegmentSize {
device.awg.Cfg.JunkPacketMinSize = 0
device.awg.Cfg.JunkPacketMaxSize = 1
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
tempAwg.Cfg.JunkPacketMaxSize,
MaxSegmentSize,
))
} else if tempAwg.Cfg.JunkPacketMaxSize < tempAwg.Cfg.JunkPacketMinSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
"maxSize: %d; should be greater than minSize: %d",
tempAwg.Cfg.JunkPacketMaxSize,
tempAwg.Cfg.JunkPacketMinSize,
))
} else {
device.awg.Cfg.JunkPacketMaxSize = tempAwg.Cfg.JunkPacketMaxSize
}
if tempAwg.Cfg.JunkPacketMaxSize != 0 {
isAwgOn = true
}
magicHeaders := make([]awg.MagicHeader, 4)
if len(tempAwg.Cfg.MagicHeaders.Values) != 4 {
return ipcErrorf(
ipc.IpcErrorInvalid,
"magic headers should have 4 values; got: %d",
len(tempAwg.Cfg.MagicHeaders.Values),
)
}
if tempAwg.Cfg.MagicHeaders.Values[0].Min > 4 {
isAwgOn = true
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
magicHeaders[0] = tempAwg.Cfg.MagicHeaders.Values[0]
MessageInitiationType = magicHeaders[0].Min
} else {
device.log.Verbosef("UAPI: Using default init type")
MessageInitiationType = DefaultMessageInitiationType
magicHeaders[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
}
if tempAwg.Cfg.MagicHeaders.Values[1].Min > 4 {
isAwgOn = true
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
magicHeaders[1] = tempAwg.Cfg.MagicHeaders.Values[1]
MessageResponseType = magicHeaders[1].Min
} else {
device.log.Verbosef("UAPI: Using default response type")
MessageResponseType = DefaultMessageResponseType
magicHeaders[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
}
if tempAwg.Cfg.MagicHeaders.Values[2].Min > 4 {
isAwgOn = true
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
magicHeaders[2] = tempAwg.Cfg.MagicHeaders.Values[2]
MessageCookieReplyType = magicHeaders[2].Min
} else {
device.log.Verbosef("UAPI: Using default underload type")
MessageCookieReplyType = DefaultMessageCookieReplyType
magicHeaders[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
}
if tempAwg.Cfg.MagicHeaders.Values[3].Min > 4 {
isAwgOn = true
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
magicHeaders[3] = tempAwg.Cfg.MagicHeaders.Values[3]
MessageTransportType = magicHeaders[3].Min
} else {
device.log.Verbosef("UAPI: Using default transport type")
MessageTransportType = DefaultMessageTransportType
magicHeaders[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
}
var err error
device.awg.Cfg.MagicHeaders, err = awg.NewMagicHeaders(magicHeaders)
if err != nil {
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
}
isSameHeaderMap := map[uint32]struct{}{
MessageInitiationType: {},
MessageResponseType: {},
MessageCookieReplyType: {},
MessageTransportType: {},
}
// size will be different if same values
if len(isSameHeaderMap) != 4 {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
MessageInitiationType,
MessageResponseType,
MessageCookieReplyType,
MessageTransportType,
),
)
}
newInitSize := MessageInitiationSize + tempAwg.Cfg.InitHeaderJunkSize
if newInitSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.Cfg.InitHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.Cfg.InitHeaderJunkSize = tempAwg.Cfg.InitHeaderJunkSize
}
if tempAwg.Cfg.InitHeaderJunkSize != 0 {
isAwgOn = true
}
newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize
if newResponseSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.Cfg.ResponseHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.Cfg.ResponseHeaderJunkSize = tempAwg.Cfg.ResponseHeaderJunkSize
}
if tempAwg.Cfg.ResponseHeaderJunkSize != 0 {
isAwgOn = true
}
newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize
if newCookieSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.Cfg.CookieReplyHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.Cfg.CookieReplyHeaderJunkSize = tempAwg.Cfg.CookieReplyHeaderJunkSize
}
if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 {
isAwgOn = true
}
newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize
if newTransportSize >= MaxSegmentSize {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
tempAwg.Cfg.TransportHeaderJunkSize,
MaxSegmentSize,
),
)
} else {
device.awg.Cfg.TransportHeaderJunkSize = tempAwg.Cfg.TransportHeaderJunkSize
}
if tempAwg.Cfg.TransportHeaderJunkSize != 0 {
isAwgOn = true
}
isSameSizeMap := map[int]struct{}{
newInitSize: {},
newResponseSize: {},
newCookieSize: {},
newTransportSize: {},
}
if len(isSameSizeMap) != 4 {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid,
`new sizes should differ; init: %d; response: %d; cookie: %d; trans: %d`,
newInitSize,
newResponseSize,
newCookieSize,
newTransportSize,
),
)
} else {
msgTypeToJunkSize = map[uint32]int{
MessageInitiationType: device.awg.Cfg.InitHeaderJunkSize,
MessageResponseType: device.awg.Cfg.ResponseHeaderJunkSize,
MessageCookieReplyType: device.awg.Cfg.CookieReplyHeaderJunkSize,
MessageTransportType: device.awg.Cfg.TransportHeaderJunkSize,
}
packetSizeToMsgType = map[int]uint32{
newInitSize: MessageInitiationType,
newResponseSize: MessageResponseType,
newCookieSize: MessageCookieReplyType,
newTransportSize: MessageTransportType,
}
}
device.awg.IsOn.SetTo(isAwgOn)
device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg)
if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else {
device.awg.HandshakeHandler = tempAwg.HandshakeHandler
device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount
device.version = VersionAwgSpecialHandshake
}
} else {
device.version = VersionAwg
}
device.awg.Mux.Unlock()
return errors.Join(errs...)
}
func (device *Device) ProcessAWGPacket(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
// TODO:
// if awg.WaitResponse.ShouldWait.IsSet() {
// awg.WaitResponse.Channel <- struct{}{}
// }
expectedMsgType, isKnownSize := packetSizeToMsgType[size]
if !isKnownSize {
msgType, err := device.handleTransport(size, packet, buffer)
if err != nil {
return 0, fmt.Errorf("handle transport: %w", err)
}
return msgType, nil
}
junkSize := msgTypeToJunkSize[expectedMsgType]
// transport size can align with other header types;
// making sure we have the right actualMsgType
actualMsgType, err := device.getMsgType(packet, junkSize)
if err != nil {
return 0, fmt.Errorf("get msg type: %w", err)
}
if actualMsgType == expectedMsgType {
*packet = (*packet)[junkSize:]
return actualMsgType, nil
}
device.log.Verbosef("awg: transport packet lined up with another msg type")
msgType, err := device.handleTransport(size, packet, buffer)
if err != nil {
return 0, fmt.Errorf("handle transport: %w", err)
}
return msgType, nil
}
func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
msgTypeValue := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeValue)
if err != nil {
return 0, fmt.Errorf("get magic header min: %w", err)
}
return msgType, nil
}
func (device *Device) handleTransport(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
junkSize := device.awg.Cfg.TransportHeaderJunkSize
msgType, err := device.getMsgType(packet, junkSize)
if err != nil {
return 0, fmt.Errorf("get msg type: %w", err)
}
if msgType != MessageTransportType {
// probably a junk packet
return 0, fmt.Errorf("Received message with unknown type: %d", msgType)
}
if junkSize > 0 {
// remove junk from buffer by shifting the packet
// this buffer is also used for decryption, so it needs to be corrected
copy((*buffer)[:size], (*packet)[junkSize:])
size -= junkSize
// need to reinitialize packet as well
(*packet) = (*packet)[:size]
}
return msgType, nil
}

63
device/magic-header.go Normal file
View File

@@ -0,0 +1,63 @@
package device
import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"strconv"
"strings"
)
type magicHeader struct {
start uint32
end uint32
}
func newMagicHeader(spec string) (*magicHeader, error) {
parts := strings.Split(spec, "-")
if len(parts) < 1 || len(parts) > 2 {
return nil, errors.New("bad format")
}
start, err := strconv.ParseUint(parts[0], 10, 32)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", parts[0], err)
}
var end uint64
if len(parts) > 1 {
end, err = strconv.ParseUint(parts[1], 10, 32)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", parts[1], err)
}
} else {
end = start
}
if end < start {
return nil, errors.New("wrong range specified")
}
return &magicHeader{
start: uint32(start),
end: uint32(end),
}, nil
}
func (h *magicHeader) GenSpec() string {
if h.start == h.end {
return fmt.Sprintf("%d", h.start)
}
return fmt.Sprintf("%d-%d", h.start, h.end)
}
func (h *magicHeader) Validate(val uint32) bool {
return h.start <= val && val <= h.end
}
func (h *magicHeader) Generate() uint32 {
high := int64(h.end - h.start + 1)
r, _ := rand.Int(rand.Reader, big.NewInt(high))
return h.start + uint32(r.Int64())
}

View File

@@ -53,17 +53,11 @@ const (
)
const (
DefaultMessageInitiationType uint32 = 1
DefaultMessageResponseType uint32 = 2
DefaultMessageCookieReplyType uint32 = 3
DefaultMessageTransportType uint32 = 4
)
var (
MessageInitiationType uint32 = DefaultMessageInitiationType
MessageResponseType uint32 = DefaultMessageResponseType
MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
MessageTransportType uint32 = DefaultMessageTransportType
MessageUnknownType uint32 = 0
MessageInitiationType uint32 = 1
MessageResponseType uint32 = 2
MessageCookieReplyType uint32 = 3
MessageTransportType uint32 = 4
)
const (
@@ -82,11 +76,6 @@ const (
MessageTransportOffsetContent = 16
)
var (
packetSizeToMsgType map[int]uint32
msgTypeToJunkSize map[uint32]int
)
/* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder
* we can treat these as a 32-bit unsigned int (for now)
@@ -205,17 +194,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(handshake.remoteStatic[:])
msgType := DefaultMessageInitiationType
if device.isAWG() {
device.awg.Mux.RLock()
msgType, err = device.awg.GetMsgType(DefaultMessageInitiationType)
if err != nil {
device.awg.Mux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)
}
device.awg.Mux.RUnlock()
}
msgType := device.headers.init.Generate()
msg := MessageInitiation{
Type: msgType,
@@ -274,13 +253,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
chainKey [blake2s.Size]byte
)
device.awg.Mux.RLock()
if msg.Type != MessageInitiationType {
device.awg.Mux.RUnlock()
return nil
}
device.awg.Mux.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -395,19 +370,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
var msg MessageResponse
if device.isAWG() {
device.awg.Mux.RLock()
msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType)
if err != nil {
device.awg.Mux.RUnlock()
return nil, fmt.Errorf("get message type: %w", err)
}
device.awg.Mux.RUnlock()
} else {
msg.Type = DefaultMessageResponseType
}
msg.Type = device.headers.response.Generate()
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
@@ -457,13 +420,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
device.awg.Mux.RLock()
if msg.Type != MessageResponseType {
device.awg.Mux.RUnlock()
return nil
}
device.awg.Mux.RUnlock()
// lookup handshake by receiver

140
device/obf.go Normal file
View File

@@ -0,0 +1,140 @@
package device
import (
"errors"
"fmt"
"strings"
)
type obfBuilder func(val string) (obf, error)
var obfBuilders = map[string]obfBuilder{
"b": newBytesObf,
"t": newTimestampObf,
"r": newRandObf,
"rc": newRandCharObf,
"rd": newRandDigitsObf,
"d": newDataObf,
"ds": newDataStringObf,
"dz": newDataSizeObf,
}
type obf interface {
Obfuscate(dst, src []byte)
Deobfuscate(dst, src []byte) bool
ObfuscatedLen(srcLen int) int
DeobfuscatedLen(srcLen int) int
}
type obfChain struct {
Spec string
obfs []obf
}
func newObfChain(spec string) (*obfChain, error) {
var (
obfs []obf
errs []error
)
remaining := spec[:]
for {
start := strings.IndexByte(remaining, '<')
if start == -1 {
break
}
end := strings.IndexByte(remaining[start:], '>')
if end == -1 {
return nil, errors.New("missing enclosing >")
}
end += start
tag := remaining[start+1 : end]
parts := strings.Fields(tag)
if len(parts) == 0 {
errs = append(errs, errors.New("empty tag"))
remaining = remaining[end+1:]
continue
}
key := parts[0]
builder, ok := obfBuilders[key]
if !ok {
errs = append(errs, fmt.Errorf("unknown tag <%s>", key))
remaining = remaining[end+1:]
continue
}
val := ""
if len(parts) > 1 {
val = parts[1]
}
o, err := builder(val)
if err != nil {
errs = append(errs, fmt.Errorf("failed to build <%s>: %w", key, err))
remaining = remaining[end+1:]
continue
}
obfs = append(obfs, o)
remaining = remaining[end+1:]
}
if len(errs) > 0 {
return nil, errors.Join(errs...)
}
return &obfChain{
Spec: spec,
obfs: obfs,
}, nil
}
func (c *obfChain) Obfuscate(dst, src []byte) {
written := 0
for _, o := range c.obfs {
obfLen := o.ObfuscatedLen(len(src))
o.Obfuscate(dst[written:written+obfLen], src)
written += obfLen
}
}
func (c *obfChain) Deobfuscate(dst, src []byte) bool {
dynamicLen := len(src) - c.ObfuscatedLen(0)
written, read := 0, 0
for _, o := range c.obfs {
deobfLen := o.DeobfuscatedLen(dynamicLen)
obfLen := o.ObfuscatedLen(deobfLen)
if !o.Deobfuscate(dst[written:written+deobfLen], src[read:read+obfLen]) {
return false
}
written += deobfLen
read += obfLen
}
return true
}
func (c *obfChain) ObfuscatedLen(n int) int {
total := 0
for _, o := range c.obfs {
total += o.ObfuscatedLen(n)
}
return total
}
func (c *obfChain) DeobfuscatedLen(n int) int {
dynamicLen := n - c.ObfuscatedLen(0)
total := 0
for _, o := range c.obfs {
total += o.DeobfuscatedLen(dynamicLen)
}
return total
}

47
device/obf_bytes.go Normal file
View File

@@ -0,0 +1,47 @@
package device
import (
"bytes"
"encoding/hex"
"errors"
"strings"
)
func newBytesObf(val string) (obf, error) {
val = strings.TrimPrefix(val, "0x")
if len(val) == 0 {
return nil, errors.New("empty argument")
}
if len(val)%2 != 0 {
return nil, errors.New("odd amount of symbols")
}
bytes, err := hex.DecodeString(val)
if err != nil {
return nil, err
}
return &bytesObf{data: bytes}, nil
}
type bytesObf struct {
data []byte
}
func (o *bytesObf) Obfuscate(dst, src []byte) {
copy(dst, o.data)
}
func (o *bytesObf) Deobfuscate(dst, src []byte) bool {
return bytes.Equal(o.data, src[:o.ObfuscatedLen(0)])
}
func (o *bytesObf) ObfuscatedLen(srcLen int) int {
return len(o.data)
}
func (o *bytesObf) DeobfuscatedLen(srcLen int) int {
return 0
}

25
device/obf_data.go Normal file
View File

@@ -0,0 +1,25 @@
package device
func newDataObf(val string) (obf, error) {
return &dataObf{}, nil
}
type dataObf struct {
}
func (obf *dataObf) Obfuscate(dst, src []byte) {
copy(dst, src)
}
func (obf *dataObf) Deobfuscate(dst, src []byte) bool {
copy(dst, src)
return true
}
func (o *dataObf) ObfuscatedLen(n int) int {
return n
}
func (o *dataObf) DeobfuscatedLen(n int) int {
return n
}

38
device/obf_datasize.go Normal file
View File

@@ -0,0 +1,38 @@
package device
import "strconv"
func newDataSizeObf(val string) (obf, error) {
length, err := strconv.Atoi(val)
if err != nil {
return nil, err
}
return &dataSizeObf{
length: length,
}, nil
}
type dataSizeObf struct {
length int
}
func (o *dataSizeObf) Obfuscate(dst, src []byte) {
srcLen := len(src)
for i := o.length - 1; i >= 0; i-- {
dst[i] = byte(srcLen & 0xFF)
srcLen >>= 8
}
}
func (o *dataSizeObf) Deobfuscate(dst, src []byte) bool {
return true
}
func (o *dataSizeObf) ObfuscatedLen(srcLen int) int {
return o.length
}
func (o *dataSizeObf) DeobfuscatedLen(srcLen int) int {
return 0
}

29
device/obf_datastring.go Normal file
View File

@@ -0,0 +1,29 @@
package device
import (
"encoding/base64"
)
func newDataStringObf(val string) (obf, error) {
return &dataStringObf{}, nil
}
type dataStringObf struct {
}
func (o *dataStringObf) Obfuscate(dst, src []byte) {
base64.RawStdEncoding.Encode(dst, src)
}
func (o *dataStringObf) Deobfuscate(dst, src []byte) bool {
base64.RawStdEncoding.Decode(dst, src)
return true
}
func (o *dataStringObf) ObfuscatedLen(n int) int {
return base64.RawStdEncoding.EncodedLen(n)
}
func (o *dataStringObf) DeobfuscatedLen(n int) int {
return base64.RawStdEncoding.DecodedLen(n)
}

39
device/obf_rand.go Normal file
View File

@@ -0,0 +1,39 @@
package device
import (
"crypto/rand"
"strconv"
)
func newRandObf(val string) (obf, error) {
length, err := strconv.Atoi(val)
if err != nil {
return nil, err
}
return &randObf{
length: length,
}, nil
}
type randObf struct {
length int
}
func (o *randObf) Obfuscate(dst, src []byte) {
rand.Read(dst[:o.length])
}
func (o *randObf) Deobfuscate(dst, src []byte) bool {
// there is no way to validate randomness :)
// assume that it is always true
return true
}
func (o *randObf) ObfuscatedLen(n int) int {
return o.length
}
func (o *randObf) DeobfuscatedLen(n int) int {
return 0
}

48
device/obf_randchars.go Normal file
View File

@@ -0,0 +1,48 @@
package device
import (
"crypto/rand"
"strconv"
"unicode"
)
const chars52 = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func newRandCharObf(val string) (obf, error) {
length, err := strconv.Atoi(val)
if err != nil {
return nil, err
}
return &randCharObf{
length: length,
}, nil
}
type randCharObf struct {
length int
}
func (o *randCharObf) Obfuscate(dst, src []byte) {
rand.Read(dst[:o.length])
for i := range dst[:o.length] {
dst[i] = chars52[dst[i]%52]
}
}
func (o *randCharObf) Deobfuscate(dst, src []byte) bool {
for _, b := range src[:o.length] {
if !unicode.IsLetter(rune(b)) {
return false
}
}
return true
}
func (o *randCharObf) ObfuscatedLen(n int) int {
return o.length
}
func (o *randCharObf) DeobfuscatedLen(n int) int {
return 0
}

48
device/obf_randdigits.go Normal file
View File

@@ -0,0 +1,48 @@
package device
import (
"crypto/rand"
"strconv"
"unicode"
)
const digits10 = "0123456789"
func newRandDigitsObf(val string) (obf, error) {
length, err := strconv.Atoi(val)
if err != nil {
return nil, err
}
return &randDigitObf{
length: length,
}, nil
}
type randDigitObf struct {
length int
}
func (o *randDigitObf) Obfuscate(dst, src []byte) {
rand.Read(dst[:o.length])
for i := range dst[:o.length] {
dst[i] = digits10[dst[i]%10]
}
}
func (o *randDigitObf) Deobfuscate(dst, src []byte) bool {
for _, b := range src[:o.length] {
if !unicode.IsDigit(rune(b)) {
return false
}
}
return true
}
func (o *randDigitObf) ObfuscatedLen(n int) int {
return o.length
}
func (o *randDigitObf) DeobfuscatedLen(n int) int {
return 0
}

31
device/obf_timestamp.go Normal file
View File

@@ -0,0 +1,31 @@
package device
import (
"encoding/binary"
"time"
)
func newTimestampObf(_ string) (obf, error) {
return &timestampObf{}, nil
}
type timestampObf struct{}
func (o *timestampObf) Obfuscate(dst, src []byte) {
t := uint32(time.Now().Unix())
binary.BigEndian.PutUint32(dst, t)
}
func (o *timestampObf) Deobfuscate(dst, src []byte) bool {
// replay attack check?
// requires time to be always synchronized
return true
}
func (o *timestampObf) ObfuscatedLen(n int) int {
return 4
}
func (o *timestampObf) DeobfuscatedLen(n int) int {
return 0
}

View File

@@ -13,7 +13,6 @@ import (
"time"
"github.com/amnezia-vpn/amneziawg-go/conn"
"github.com/amnezia-vpn/amneziawg-go/device/awg"
)
type Peer struct {
@@ -114,16 +113,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}
func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error {
err := peer.SendBuffers(buffers)
if err == nil {
awg.PacketCounter.Add(uint64(len(buffers)))
return nil
}
return err
}
func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()

View File

@@ -97,13 +97,13 @@ func (device *Device) RoutineReceiveIncoming(
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
for i := range bufsArrs {
for i := range maxBatchSize {
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
}
defer func() {
for i := 0; i < maxBatchSize; i++ {
for i := range maxBatchSize {
if bufsArrs[i] != nil {
device.PutMessageBuffer(bufsArrs[i])
}
@@ -129,7 +129,6 @@ func (device *Device) RoutineReceiveIncoming(
}
deathSpiral = 0
device.awg.Mux.RLock()
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
@@ -138,16 +137,12 @@ func (device *Device) RoutineReceiveIncoming(
// check size of packet
packet := bufsArrs[i][:size]
var msgType uint32
if device.isAWG() {
msgType, err = device.ProcessAWGPacket(size, &packet, bufsArrs[i])
if err != nil {
device.log.Verbosef("awg: process packet: %v", err)
continue
}
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
// get message padding and type based on information from S1-S4 and H1-H4
msgType, padding := device.DeterminePacketTypeAndPadding(packet, MessageUnknownType)
if padding > 0 {
copy(packet, packet[padding:])
packet = packet[:len(packet)-padding]
}
switch msgType {
@@ -233,7 +228,6 @@ func (device *Device) RoutineReceiveIncoming(
default:
}
}
device.awg.Mux.RUnlock()
for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elemsContainer
@@ -291,9 +285,6 @@ func (device *Device) RoutineHandshake(id int) {
device.log.Verbosef("Routine: handshake worker %d - started", id)
for elem := range device.queue.handshake.c {
device.awg.Mux.RLock()
// handle cookie fields and ratelimiting
switch elem.msgType {
@@ -450,7 +441,6 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
device.awg.Mux.RUnlock()
device.PutMessageBuffer(elem.buffer)
}
}
@@ -569,3 +559,57 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device.PutInboundElementsContainer(elemsContainer)
}
}
func (device *Device) DeterminePacketTypeAndPadding(packet []byte, expectedType uint32) (uint32, int) {
size := len(packet)
if expectedType == MessageUnknownType || expectedType == MessageInitiationType {
padding := device.paddings.init
header := device.headers.init
if size == padding+MessageInitiationSize {
data := packet[padding:]
if header.Validate(binary.LittleEndian.Uint32(data)) {
return MessageInitiationType, padding
}
}
}
if expectedType == MessageUnknownType || expectedType == MessageResponseType {
padding := device.paddings.response
header := device.headers.response
if size == padding+MessageResponseSize {
data := packet[padding:]
if header.Validate(binary.LittleEndian.Uint32(data)) {
return MessageResponseType, padding
}
}
}
if expectedType == MessageUnknownType || expectedType == MessageCookieReplyType {
padding := device.paddings.cookie
header := device.headers.cookie
if size == padding+MessageCookieReplySize {
data := packet[padding:]
if header.Validate(binary.LittleEndian.Uint32(data)) {
return MessageCookieReplyType, padding
}
}
}
if expectedType == MessageUnknownType || expectedType == MessageTransportType {
padding := device.paddings.transport
header := device.headers.transport
if size >= padding+MessageTransportHeaderSize {
data := packet[padding:]
if header.Validate(binary.LittleEndian.Uint32(data)) {
return MessageTransportType, padding
}
}
}
return MessageUnknownType, 0
}

View File

@@ -7,8 +7,10 @@ package device
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"math/big"
"net"
"os"
"sync"
@@ -123,41 +125,28 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
return err
}
var sendBuffer [][]byte
// so only packet processed for cookie generation
var junkedHeader []byte
if peer.device.version >= VersionAwg {
var junks [][]byte
if peer.device.version == VersionAwgSpecialHandshake {
peer.device.awg.Mux.RLock()
// set junks depending on packet type
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
if junks != nil {
peer.device.log.Verbosef("%v - Special junks sent", peer)
}
peer.device.awg.Mux.RUnlock()
} else {
junks = make([][]byte, 0, peer.device.awg.Cfg.JunkPacketCount)
for _, ipacket := range peer.device.ipackets {
if ipacket != nil {
buf := make([]byte, ipacket.ObfuscatedLen(0))
ipacket.Obfuscate(buf, nil)
sendBuffer = append(sendBuffer, buf)
}
peer.device.awg.Mux.RLock()
peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
peer.device.awg.Mux.RUnlock()
}
if len(junks) > 0 {
err = peer.SendBuffers(junks)
jc := peer.device.junk.count
jmin := peer.device.junk.min
jmax := peer.device.junk.max
if err != nil {
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
return err
}
}
for i := 0; i < jc; i++ {
nBig, _ := rand.Int(rand.Reader, big.NewInt(int64(jmax-jmin+1)))
n := int(nBig.Int64()) + jmin
junkedHeader, err = peer.device.awg.CreateInitHeaderJunk()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
buf := make([]byte, n)
rand.Read(buf)
sendBuffer = append(sendBuffer, buf)
}
var buf [MessageInitiationSize]byte
@@ -165,14 +154,20 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
sendBuffer = append(sendBuffer, junkedHeader)
if padding := peer.device.paddings.init; padding > 0 {
buf := make([]byte, padding+len(packet))
rand.Read(buf[:padding])
copy(buf[padding:], packet)
packet = buf
}
err = peer.SendAndCountBuffers(sendBuffer)
sendBuffer = append(sendBuffer, packet)
err = peer.SendBuffers(sendBuffer)
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@@ -194,19 +189,12 @@ func (peer *Peer) SendHandshakeResponse() error {
return err
}
junkedHeader, err := peer.device.awg.CreateResponseHeaderJunk()
if err != nil {
peer.device.log.Errorf("%v - %v", peer, err)
return err
}
var buf [MessageResponseSize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
junkedHeader = append(junkedHeader, packet...)
err = peer.BeginSymmetricSession()
if err != nil {
@@ -218,32 +206,26 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
if padding := peer.device.paddings.response; padding > 0 {
buf := make([]byte, padding+len(packet))
rand.Read(buf[:padding])
copy(buf[padding:], packet)
packet = buf
}
// TODO: allocation could be avoided
err = peer.SendAndCountBuffers([][]byte{junkedHeader})
err = peer.SendBuffers([][]byte{packet})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
return err
}
func (device *Device) SendHandshakeCookie(
initiatingElem *QueueHandshakeElement,
) error {
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
msgType := DefaultMessageCookieReplyType
if device.isAWG() {
device.awg.Mux.RLock()
var err error
msgType, err = device.awg.GetMsgType(DefaultMessageCookieReplyType)
device.awg.Mux.RUnlock()
if err != nil {
device.log.Errorf("Get message type for cookie reply: %v", err)
return err
}
}
msgType := device.headers.cookie.Generate()
reply, err := device.cookieChecker.CreateReply(
initiatingElem.packet,
@@ -256,19 +238,20 @@ func (device *Device) SendHandshakeCookie(
return err
}
junkedHeader, err := device.awg.CreateCookieReplyHeaderJunk()
if err != nil {
device.log.Errorf("%v - %v", device, err)
return err
}
var buf [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
packet := writer.Bytes()
if padding := device.paddings.cookie; padding > 0 {
buf := make([]byte, padding+len(packet))
rand.Read(buf[:padding])
copy(buf[padding:], packet)
packet = buf
}
junkedHeader = append(junkedHeader, writer.Bytes()...)
// TODO: allocation could be avoided
device.net.bind.Send([][]byte{junkedHeader}, initiatingElem.endpoint)
device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint)
return nil
}
@@ -532,18 +515,7 @@ func (device *Device) RoutineEncryption(id int) {
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
msgType := DefaultMessageTransportType
if device.isAWG() {
device.awg.Mux.RLock()
var err error
msgType, err = device.awg.GetMsgType(DefaultMessageTransportType)
device.awg.Mux.RUnlock()
if err != nil {
device.log.Errorf("get message type for transport: %v", err)
continue
}
}
msgType := device.headers.transport.Generate()
binary.LittleEndian.PutUint32(fieldType, msgType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
@@ -603,13 +575,15 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
junkedHeader, err := device.awg.CreateTransportHeaderJunk(len(elem.packet))
if err != nil {
device.log.Errorf("%v - %v", device, err)
continue
if padding := device.paddings.transport; padding > 0 {
// elem.packet is stored at the start of elem.buffer
// with zero padding
for i := len(elem.packet) - 1; i >= 0; i-- {
elem.buffer[i+padding] = elem.buffer[i]
}
rand.Read(elem.buffer[:padding])
elem.packet = elem.buffer[:padding+len(elem.packet)]
}
elem.packet = append(junkedHeader, elem.packet...)
}
bufs = append(bufs, elem.packet)
}
@@ -617,7 +591,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
err := peer.SendAndCountBuffers(bufs)
err := peer.SendBuffers(bufs)
if dataSent {
peer.timersDataSent()
}

View File

@@ -18,7 +18,6 @@ import (
"sync"
"time"
"github.com/amnezia-vpn/amneziawg-go/device/awg"
"github.com/amnezia-vpn/amneziawg-go/ipc"
)
@@ -98,42 +97,53 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
if device.isAWG() {
if device.awg.Cfg.JunkPacketCount != 0 {
sendf("jc=%d", device.awg.Cfg.JunkPacketCount)
}
if device.awg.Cfg.JunkPacketMinSize != 0 {
sendf("jmin=%d", device.awg.Cfg.JunkPacketMinSize)
}
if device.awg.Cfg.JunkPacketMaxSize != 0 {
sendf("jmax=%d", device.awg.Cfg.JunkPacketMaxSize)
}
if device.awg.Cfg.InitHeaderJunkSize != 0 {
sendf("s1=%d", device.awg.Cfg.InitHeaderJunkSize)
}
if device.awg.Cfg.ResponseHeaderJunkSize != 0 {
sendf("s2=%d", device.awg.Cfg.ResponseHeaderJunkSize)
}
if device.awg.Cfg.CookieReplyHeaderJunkSize != 0 {
sendf("s3=%d", device.awg.Cfg.CookieReplyHeaderJunkSize)
}
if device.awg.Cfg.TransportHeaderJunkSize != 0 {
sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize)
}
for i, magicHeader := range device.awg.Cfg.MagicHeaders.Values {
if magicHeader.Min > 4 {
if magicHeader.Min == magicHeader.Max {
sendf("h%d=%d", i+1, magicHeader.Min)
continue
}
if device.junk.count != 0 {
sendf("jc=%d", device.junk.count)
}
sendf("h%d=%d-%d", i+1, magicHeader.Min, magicHeader.Max)
}
}
if device.junk.min != 0 {
sendf("jmin=%d", device.junk.min)
}
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
for _, field := range specialJunkIpcFields {
sendf("%s=%s", field.Key, field.Value)
if device.junk.max != 0 {
sendf("jmax=%d", device.junk.max)
}
if device.paddings.init != 0 {
sendf("s1=%d", device.paddings.init)
}
if device.paddings.response != 0 {
sendf("s2=%d", device.paddings.response)
}
if device.paddings.cookie != 0 {
sendf("s3=%d", device.paddings.cookie)
}
if device.paddings.transport != 0 {
sendf("s4=%d", device.paddings.transport)
}
if device.headers.init != nil {
sendf("h1=%s", device.headers.init.GenSpec())
}
if device.headers.response != nil {
sendf("h2=%s", device.headers.response.GenSpec())
}
if device.headers.cookie != nil {
sendf("h3=%s", device.headers.cookie.GenSpec())
}
if device.headers.transport != nil {
sendf("h4=%s", device.headers.transport.GenSpec())
}
for i, ipacket := range device.ipackets {
if ipacket != nil {
sendf("i%d=%s", i+1, ipacket.Spec)
}
}
@@ -187,20 +197,18 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
}
}()
ipcDev := new(ipcSetDevice)
peer := new(ipcSetPeer)
deviceConfig := true
tempAwg := awg.Protocol{}
tempAwg.Cfg.MagicHeaders.Values = make([]awg.MagicHeader, 4)
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
err := device.handlePostConfig(&tempAwg)
err := ipcDev.mergeWithDevice(device)
if err != nil {
return err
return ipcErrorf(ipc.IpcErrorInvalid, "failed to merge with device: %w", err)
}
peer.handlePostConfig()
return nil
@@ -229,7 +237,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error
if deviceConfig {
err = device.handleDeviceLine(key, value, &tempAwg)
err = device.handleDeviceLine(key, value)
} else {
err = device.handlePeerLine(peer, key, value)
}
@@ -237,9 +245,9 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err
}
}
err = device.handlePostConfig(&tempAwg)
err = ipcDev.mergeWithDevice(device)
if err != nil {
return err
return ipcErrorf(ipc.IpcErrorInvalid, "failed to merge with device: %w", err)
}
peer.handlePostConfig()
@@ -249,7 +257,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil
}
func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error {
func (device *Device) handleDeviceLine(key, value string) error {
switch key {
case "private_key":
var sk NoisePrivateKey
@@ -300,112 +308,145 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
device.RemoveAllPeers()
case "jc":
junkPacketCount, err := strconv.Atoi(value)
jc, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse jc: %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
tempAwg.Cfg.JunkPacketCount = junkPacketCount
tempAwg.Cfg.IsSet = true
if jc <= 0 {
return ipcErrorf(ipc.IpcErrorInvalid, "jc must be a positive value")
}
device.log.Verbosef("UAPI: Updating junk count")
device.junk.count = jc
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
jmin, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse jmin: %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
tempAwg.Cfg.JunkPacketMinSize = junkPacketMinSize
tempAwg.Cfg.IsSet = true
if jmin <= 0 {
return ipcErrorf(ipc.IpcErrorInvalid, "jmin must be a positive value")
}
device.log.Verbosef("UAPI: Updating junk min")
device.junk.min = jmin
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
jmax, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse jmax: %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
tempAwg.Cfg.JunkPacketMaxSize = junkPacketMaxSize
tempAwg.Cfg.IsSet = true
if jmax <= 0 {
return ipcErrorf(ipc.IpcErrorInvalid, "jmax must be a positive value")
}
device.log.Verbosef("UAPI: Updating junk max")
device.junk.max = jmax
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
padding, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse s1: %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
tempAwg.Cfg.InitHeaderJunkSize = initPacketJunkSize
tempAwg.Cfg.IsSet = true
if padding < 0 {
return ipcErrorf(ipc.IpcErrorInvalid, "s1 must be non-negative")
}
device.log.Verbosef("UAPI: Updating s1 padding")
device.paddings.init = padding
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
padding, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse s2: %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
tempAwg.Cfg.ResponseHeaderJunkSize = responsePacketJunkSize
tempAwg.Cfg.IsSet = true
if padding < 0 {
return ipcErrorf(ipc.IpcErrorInvalid, "s2 must be non-negative")
}
device.log.Verbosef("UAPI: Updating s2 padding")
device.paddings.response = padding
case "s3":
cookieReplyPacketJunkSize, err := strconv.Atoi(value)
padding, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse s3: %w", err)
}
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
tempAwg.Cfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
tempAwg.Cfg.IsSet = true
if padding < 0 {
return ipcErrorf(ipc.IpcErrorInvalid, "s3 must be non-negative")
}
device.log.Verbosef("UAPI: Updating s3 padding")
device.paddings.cookie = padding
case "s4":
transportPacketJunkSize, err := strconv.Atoi(value)
padding, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse s4: %w", err)
}
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize
tempAwg.Cfg.IsSet = true
if padding < 0 {
return ipcErrorf(ipc.IpcErrorInvalid, "s4 must be non-negative")
}
device.log.Verbosef("UAPI: Updating s4 padding")
device.paddings.transport = padding
case "h1":
initMagicHeader, err := awg.ParseMagicHeader(key, value)
header, err := newMagicHeader(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse H1: %w", err)
}
device.headers.init = header
tempAwg.Cfg.MagicHeaders.Values[0] = initMagicHeader
tempAwg.Cfg.IsSet = true
case "h2":
responseMagicHeader, err := awg.ParseMagicHeader(key, value)
header, err := newMagicHeader(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse H2: %w", err)
}
device.headers.response = header
tempAwg.Cfg.MagicHeaders.Values[1] = responseMagicHeader
tempAwg.Cfg.IsSet = true
case "h3":
cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
header, err := newMagicHeader(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse H3: %w", err)
}
device.headers.cookie = header
tempAwg.Cfg.MagicHeaders.Values[2] = cookieReplyMagicHeader
tempAwg.Cfg.IsSet = true
case "h4":
transportMagicHeader, err := awg.ParseMagicHeader(key, value)
header, err := newMagicHeader(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse H4: %w", err)
}
device.headers.transport = header
tempAwg.Cfg.MagicHeaders.Values[3] = transportMagicHeader
tempAwg.Cfg.IsSet = true
case "i1", "i2", "i3", "i4", "i5":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key)
return nil
}
generators, err := awg.ParseTagJunkGenerator(key, value)
case "i1":
chain, err := newObfChain(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse I1: %w", err)
}
device.log.Verbosef("UAPI: Updating %s", key)
tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators)
tempAwg.HandshakeHandler.IsSet = true
device.ipackets[0] = chain
case "i2":
chain, err := newObfChain(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse I2: %w", err)
}
device.ipackets[1] = chain
case "i3":
chain, err := newObfChain(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse I3: %w", err)
}
device.ipackets[2] = chain
case "i4":
chain, err := newObfChain(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse I4: %w", err)
}
device.ipackets[3] = chain
case "i5":
chain, err := newObfChain(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse I5: %w", err)
}
device.ipackets[4] = chain
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
}
@@ -654,3 +695,49 @@ func (device *Device) IpcHandle(socket net.Conn) {
buffered.Flush()
}
}
type ipcSetDevice struct {
headers struct {
init *magicHeader
response *magicHeader
cookie *magicHeader
transport *magicHeader
}
}
func (d *ipcSetDevice) mergeWithDevice(device *Device) error {
if d.headers.init == nil {
d.headers.init = device.headers.init
}
if d.headers.response == nil {
d.headers.response = device.headers.response
}
if d.headers.cookie == nil {
d.headers.cookie = device.headers.cookie
}
if d.headers.transport == nil {
d.headers.transport = device.headers.transport
}
headers := []*magicHeader{d.headers.init, d.headers.response, d.headers.cookie, d.headers.transport}
for i := 0; i < len(headers); i++ {
for j := i + 1; j < len(headers); j++ {
left := headers[i]
right := headers[j]
if left.start <= right.end && right.start <= left.end {
return errors.New("headers must not overlap")
}
}
}
device.headers.init = d.headers.init
device.headers.response = d.headers.response
device.headers.cookie = d.headers.cookie
device.headers.transport = d.headers.transport
return nil
}