mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2026-05-17 00:05:50 +03:00
fix: refactor processing of junk packets (#103)
- fix the bug that transport packet interprets as init/resp/cookie with the same size - cleanup error responses - reduce buffer allocations
This commit is contained in:
@@ -1,90 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/tevino/abool"
|
||||
)
|
||||
|
||||
type Cfg struct {
|
||||
IsSet bool
|
||||
JunkPacketCount int
|
||||
JunkPacketMinSize int
|
||||
JunkPacketMaxSize int
|
||||
InitHeaderJunkSize int
|
||||
ResponseHeaderJunkSize int
|
||||
CookieReplyHeaderJunkSize int
|
||||
TransportHeaderJunkSize int
|
||||
|
||||
MagicHeaders MagicHeaders
|
||||
}
|
||||
|
||||
type Protocol struct {
|
||||
IsOn abool.AtomicBool
|
||||
// TODO: revision the need of the mutex
|
||||
Mux sync.RWMutex
|
||||
Cfg Cfg
|
||||
JunkCreator JunkCreator
|
||||
|
||||
HandshakeHandler SpecialHandshakeHandler
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateInitHeaderJunk() ([]byte, error) {
|
||||
protocol.Mux.RLock()
|
||||
defer protocol.Mux.RUnlock()
|
||||
|
||||
return protocol.createHeaderJunk(protocol.Cfg.InitHeaderJunkSize, 0)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateResponseHeaderJunk() ([]byte, error) {
|
||||
protocol.Mux.RLock()
|
||||
defer protocol.Mux.RUnlock()
|
||||
|
||||
return protocol.createHeaderJunk(protocol.Cfg.ResponseHeaderJunkSize, 0)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateCookieReplyHeaderJunk() ([]byte, error) {
|
||||
protocol.Mux.RLock()
|
||||
defer protocol.Mux.RUnlock()
|
||||
|
||||
return protocol.createHeaderJunk(protocol.Cfg.CookieReplyHeaderJunkSize, 0)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) CreateTransportHeaderJunk(packetSize int) ([]byte, error) {
|
||||
protocol.Mux.RLock()
|
||||
defer protocol.Mux.RUnlock()
|
||||
|
||||
return protocol.createHeaderJunk(protocol.Cfg.TransportHeaderJunkSize, packetSize)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) createHeaderJunk(junkSize int, extraSize int) ([]byte, error) {
|
||||
if junkSize == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
buf := make([]byte, 0, junkSize+extraSize)
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
|
||||
err := protocol.JunkCreator.AppendJunk(writer, junkSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("append junk: %w", err)
|
||||
}
|
||||
|
||||
return writer.Bytes(), nil
|
||||
}
|
||||
|
||||
func (protocol *Protocol) GetMagicHeaderMinFor(msgType uint32) (uint32, error) {
|
||||
for _, magicHeader := range protocol.Cfg.MagicHeaders.Values {
|
||||
if magicHeader.Min <= msgType && msgType <= magicHeader.Max {
|
||||
return magicHeader.Min, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no header for value: %d", msgType)
|
||||
}
|
||||
|
||||
func (protocol *Protocol) GetMsgType(defaultMsgType uint32) (uint32, error) {
|
||||
return protocol.Cfg.MagicHeaders.Get(defaultMsgType)
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package internal
|
||||
|
||||
type mockGenerator struct {
|
||||
size int
|
||||
}
|
||||
|
||||
func NewMockGenerator(size int) mockGenerator {
|
||||
return mockGenerator{size: size}
|
||||
}
|
||||
|
||||
func (m mockGenerator) Generate() []byte {
|
||||
return make([]byte, m.size)
|
||||
}
|
||||
|
||||
func (m mockGenerator) Size() int {
|
||||
return m.size
|
||||
}
|
||||
|
||||
func (m mockGenerator) Name() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
type mockByteGenerator struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func NewMockByteGenerator(data []byte) mockByteGenerator {
|
||||
return mockByteGenerator{data: data}
|
||||
}
|
||||
|
||||
func (bg mockByteGenerator) Generate() []byte {
|
||||
return bg.data
|
||||
}
|
||||
|
||||
func (bg mockByteGenerator) Size() int {
|
||||
return len(bg.data)
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type JunkCreator struct {
|
||||
cfg Cfg
|
||||
randomGenerator PRNG[int]
|
||||
}
|
||||
|
||||
// TODO: refactor param to only pass the junk related params
|
||||
func NewJunkCreator(cfg Cfg) JunkCreator {
|
||||
return JunkCreator{cfg: cfg, randomGenerator: NewPRNG[int]()}
|
||||
}
|
||||
|
||||
// Should be called with awg mux RLocked
|
||||
func (jc *JunkCreator) CreateJunkPackets(junks *[][]byte) {
|
||||
if jc.cfg.JunkPacketCount == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for range jc.cfg.JunkPacketCount {
|
||||
packetSize := jc.randomPacketSize()
|
||||
junk := jc.randomJunkWithSize(packetSize)
|
||||
*junks = append(*junks, junk)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Should be called with awg mux RLocked
|
||||
func (jc *JunkCreator) randomPacketSize() int {
|
||||
return jc.randomGenerator.RandomSizeInRange(jc.cfg.JunkPacketMinSize, jc.cfg.JunkPacketMaxSize)
|
||||
}
|
||||
|
||||
// Should be called with awg mux RLocked
|
||||
func (jc *JunkCreator) AppendJunk(writer *bytes.Buffer, size int) error {
|
||||
headerJunk := jc.randomJunkWithSize(size)
|
||||
_, err := writer.Write(headerJunk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write header junk: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Should be called with awg mux RLocked
|
||||
func (jc *JunkCreator) randomJunkWithSize(size int) []byte {
|
||||
return jc.randomGenerator.ReadSize(size)
|
||||
}
|
||||
@@ -1,97 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setUpJunkCreator() JunkCreator {
|
||||
mh, _ := NewMagicHeaders(
|
||||
[]MagicHeader{
|
||||
NewMagicHeaderSameValue(123456),
|
||||
NewMagicHeaderSameValue(67543),
|
||||
NewMagicHeaderSameValue(32345),
|
||||
NewMagicHeaderSameValue(123123),
|
||||
},
|
||||
)
|
||||
|
||||
jc := NewJunkCreator(Cfg{
|
||||
IsSet: true,
|
||||
JunkPacketCount: 5,
|
||||
JunkPacketMinSize: 500,
|
||||
JunkPacketMaxSize: 1000,
|
||||
InitHeaderJunkSize: 30,
|
||||
ResponseHeaderJunkSize: 40,
|
||||
MagicHeaders: mh,
|
||||
})
|
||||
|
||||
return jc
|
||||
}
|
||||
|
||||
func Test_junkCreator_createJunkPackets(t *testing.T) {
|
||||
jc := setUpJunkCreator()
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
got := make([][]byte, 0, jc.cfg.JunkPacketCount)
|
||||
jc.CreateJunkPackets(&got)
|
||||
seen := make(map[string]bool)
|
||||
for _, junk := range got {
|
||||
key := string(junk)
|
||||
if seen[key] {
|
||||
t.Errorf(
|
||||
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
|
||||
got,
|
||||
junk,
|
||||
)
|
||||
return
|
||||
}
|
||||
seen[key] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
jc := setUpJunkCreator()
|
||||
r1 := jc.randomJunkWithSize(10)
|
||||
r2 := jc.randomJunkWithSize(10)
|
||||
fmt.Printf("%v\n%v\n", r1, r2)
|
||||
if bytes.Equal(r1, r2) {
|
||||
t.Errorf("same junks")
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_junkCreator_randomPacketSize(t *testing.T) {
|
||||
jc := setUpJunkCreator()
|
||||
for range [30]struct{}{} {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
if got := jc.randomPacketSize(); jc.cfg.JunkPacketMinSize > got ||
|
||||
got > jc.cfg.JunkPacketMaxSize {
|
||||
t.Errorf(
|
||||
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
|
||||
got,
|
||||
jc.cfg.JunkPacketMinSize,
|
||||
jc.cfg.JunkPacketMaxSize,
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_junkCreator_appendJunk(t *testing.T) {
|
||||
jc := setUpJunkCreator()
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
s := "apple"
|
||||
buffer := bytes.NewBuffer([]byte(s))
|
||||
err := jc.AppendJunk(buffer, 30)
|
||||
if err != nil &&
|
||||
buffer.Len() != len(s)+30 {
|
||||
t.Error("appendWithJunk() size don't match")
|
||||
}
|
||||
read := make([]byte, 50)
|
||||
buffer.Read(read)
|
||||
fmt.Println(string(read))
|
||||
})
|
||||
}
|
||||
@@ -1,97 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MagicHeader struct {
|
||||
Min uint32
|
||||
Max uint32
|
||||
}
|
||||
|
||||
func NewMagicHeaderSameValue(value uint32) MagicHeader {
|
||||
return MagicHeader{Min: value, Max: value}
|
||||
}
|
||||
|
||||
func NewMagicHeader(min, max uint32) (MagicHeader, error) {
|
||||
if min > max {
|
||||
return MagicHeader{}, fmt.Errorf("min (%d) cannot be greater than max (%d)", min, max)
|
||||
}
|
||||
|
||||
return MagicHeader{Min: min, Max: max}, nil
|
||||
}
|
||||
|
||||
func ParseMagicHeader(key, value string) (MagicHeader, error) {
|
||||
hyphenIdx := strings.Index(value, "-")
|
||||
if hyphenIdx == -1 {
|
||||
// if there is no hyphen, we treat it as single magic header value
|
||||
magicHeader, err := strconv.ParseUint(value, 10, 32)
|
||||
if err != nil {
|
||||
return MagicHeader{}, fmt.Errorf("parse key: %s; value: %s; %w", key, value, err)
|
||||
}
|
||||
|
||||
return NewMagicHeader(uint32(magicHeader), uint32(magicHeader))
|
||||
}
|
||||
|
||||
minStr := value[:hyphenIdx]
|
||||
maxStr := value[hyphenIdx+1:]
|
||||
if len(minStr) == 0 || len(maxStr) == 0 {
|
||||
return MagicHeader{}, fmt.Errorf("invalid value for key: %s; value: %s; expected format: min-max", key, value)
|
||||
}
|
||||
|
||||
min, err := strconv.ParseUint(minStr, 10, 32)
|
||||
if err != nil {
|
||||
return MagicHeader{}, fmt.Errorf("parse min key: %s; value: %s; %w", key, minStr, err)
|
||||
}
|
||||
|
||||
max, err := strconv.ParseUint(maxStr, 10, 32)
|
||||
if err != nil {
|
||||
return MagicHeader{}, fmt.Errorf("parse max key: %s; value: %s; %w", key, maxStr, err)
|
||||
}
|
||||
|
||||
magicHeader, err := NewMagicHeader(uint32(min), uint32(max))
|
||||
if err != nil {
|
||||
return MagicHeader{}, fmt.Errorf("new magicHeader key: %s; value: %s-%s; %w", key, minStr, maxStr, err)
|
||||
}
|
||||
|
||||
return magicHeader, nil
|
||||
}
|
||||
|
||||
type MagicHeaders struct {
|
||||
Values []MagicHeader
|
||||
randomGenerator RandomNumberGenerator[uint32]
|
||||
}
|
||||
|
||||
func NewMagicHeaders(headerValues []MagicHeader) (MagicHeaders, error) {
|
||||
if len(headerValues) != 4 {
|
||||
return MagicHeaders{}, fmt.Errorf("all header types should be included: %v", headerValues)
|
||||
}
|
||||
|
||||
sortedMagicHeaders := slices.SortedFunc(slices.Values(headerValues), func(lhs MagicHeader, rhs MagicHeader) int {
|
||||
return cmp.Compare(lhs.Min, rhs.Min)
|
||||
})
|
||||
|
||||
for i := range 3 {
|
||||
if sortedMagicHeaders[i].Max >= sortedMagicHeaders[i+1].Min {
|
||||
return MagicHeaders{}, fmt.Errorf(
|
||||
"magic headers shouldn't overlap; %v > %v",
|
||||
sortedMagicHeaders[i].Max,
|
||||
sortedMagicHeaders[i+1].Min,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return MagicHeaders{Values: headerValues, randomGenerator: NewPRNG[uint32]()}, nil
|
||||
}
|
||||
|
||||
func (mh *MagicHeaders) Get(defaultMsgType uint32) (uint32, error) {
|
||||
if defaultMsgType == 0 || defaultMsgType > 4 {
|
||||
return 0, fmt.Errorf("invalid msg type: %d", defaultMsgType)
|
||||
}
|
||||
|
||||
return mh.randomGenerator.RandomSizeInRange(mh.Values[defaultMsgType-1].Min, mh.Values[defaultMsgType-1].Max), nil
|
||||
}
|
||||
@@ -1,488 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewMagicHeaderSameValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value uint32
|
||||
expected MagicHeader
|
||||
}{
|
||||
{
|
||||
name: "zero value",
|
||||
value: 0,
|
||||
expected: MagicHeader{Min: 0, Max: 0},
|
||||
},
|
||||
{
|
||||
name: "small value",
|
||||
value: 1,
|
||||
expected: MagicHeader{Min: 1, Max: 1},
|
||||
},
|
||||
{
|
||||
name: "large value",
|
||||
value: 4294967295, // max uint32
|
||||
expected: MagicHeader{Min: 4294967295, Max: 4294967295},
|
||||
},
|
||||
{
|
||||
name: "medium value",
|
||||
value: 1000,
|
||||
expected: MagicHeader{Min: 1000, Max: 1000},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := NewMagicHeaderSameValue(tt.value)
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMagicHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
min uint32
|
||||
max uint32
|
||||
expected MagicHeader
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid range",
|
||||
min: 1,
|
||||
max: 10,
|
||||
expected: MagicHeader{Min: 1, Max: 10},
|
||||
},
|
||||
{
|
||||
name: "equal values",
|
||||
min: 5,
|
||||
max: 5,
|
||||
expected: MagicHeader{Min: 5, Max: 5},
|
||||
},
|
||||
{
|
||||
name: "zero range",
|
||||
min: 0,
|
||||
max: 0,
|
||||
expected: MagicHeader{Min: 0, Max: 0},
|
||||
},
|
||||
{
|
||||
name: "max uint32 range",
|
||||
min: 4294967294,
|
||||
max: 4294967295,
|
||||
expected: MagicHeader{Min: 4294967294, Max: 4294967295},
|
||||
},
|
||||
{
|
||||
name: "min greater than max",
|
||||
min: 10,
|
||||
max: 5,
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "min (10) cannot be greater than max (5)",
|
||||
},
|
||||
{
|
||||
name: "large min greater than max",
|
||||
min: 4294967295,
|
||||
max: 1,
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "min (4294967295) cannot be greater than max (1)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result, err := NewMagicHeader(tt.min, tt.max)
|
||||
|
||||
if tt.errorMsg != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMsg)
|
||||
require.Equal(t, MagicHeader{}, result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMagicHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
expected MagicHeader
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "single value",
|
||||
key: "header1",
|
||||
value: "100",
|
||||
expected: MagicHeader{Min: 100, Max: 100},
|
||||
},
|
||||
{
|
||||
name: "valid range",
|
||||
key: "header2",
|
||||
value: "10-20",
|
||||
expected: MagicHeader{Min: 10, Max: 20},
|
||||
},
|
||||
{
|
||||
name: "zero single value",
|
||||
key: "header3",
|
||||
value: "0",
|
||||
expected: MagicHeader{Min: 0, Max: 0},
|
||||
},
|
||||
{
|
||||
name: "zero range",
|
||||
key: "header4",
|
||||
value: "0-0",
|
||||
expected: MagicHeader{Min: 0, Max: 0},
|
||||
},
|
||||
{
|
||||
name: "max uint32 single",
|
||||
key: "header5",
|
||||
value: "4294967295",
|
||||
expected: MagicHeader{Min: 4294967295, Max: 4294967295},
|
||||
},
|
||||
{
|
||||
name: "max uint32 range",
|
||||
key: "header6",
|
||||
value: "4294967294-4294967295",
|
||||
expected: MagicHeader{Min: 4294967294, Max: 4294967295},
|
||||
},
|
||||
{
|
||||
name: "invalid single value - not number",
|
||||
key: "header7",
|
||||
value: "abc",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse key: header7; value: abc;",
|
||||
},
|
||||
{
|
||||
name: "invalid single value - negative",
|
||||
key: "header8",
|
||||
value: "-5",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "invalid value for key: header8; value: -5;",
|
||||
},
|
||||
{
|
||||
name: "invalid single value - too large",
|
||||
key: "header9",
|
||||
value: "4294967296",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse key: header9; value: 4294967296;",
|
||||
},
|
||||
{
|
||||
name: "invalid range - min not number",
|
||||
key: "header10",
|
||||
value: "abc-10",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse min key: header10; value: abc;",
|
||||
},
|
||||
{
|
||||
name: "invalid range - max not number",
|
||||
key: "header11",
|
||||
value: "10-abc",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse max key: header11; value: abc;",
|
||||
},
|
||||
{
|
||||
name: "invalid range - min greater than max",
|
||||
key: "header12",
|
||||
value: "20-10",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "new magicHeader key: header12; value: 20-10;",
|
||||
},
|
||||
{
|
||||
name: "invalid range - too many parts",
|
||||
key: "header13",
|
||||
value: "10-20-30",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse key: header13; value: 10-20-30;",
|
||||
},
|
||||
{
|
||||
name: "empty value",
|
||||
key: "header14",
|
||||
value: "",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "parse key: header14; value: ;",
|
||||
},
|
||||
{
|
||||
name: "hyphen only",
|
||||
key: "header15",
|
||||
value: "-",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "invalid value for key: header15; value: -;",
|
||||
},
|
||||
{
|
||||
name: "empty min",
|
||||
key: "header16",
|
||||
value: "-10",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "invalid value for key: header16; value: -10;",
|
||||
},
|
||||
{
|
||||
name: "empty max",
|
||||
key: "header17",
|
||||
value: "10-",
|
||||
expected: MagicHeader{},
|
||||
errorMsg: "invalid value for key: header17; value: 10-;",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result, err := ParseMagicHeader(tt.key, tt.value)
|
||||
|
||||
if tt.errorMsg != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMsg)
|
||||
require.Equal(t, MagicHeader{}, result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMagicHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
magicHeaders []MagicHeader
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid non-overlapping headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 11, Max: 20},
|
||||
{Min: 21, Max: 30},
|
||||
{Min: 31, Max: 40},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid adjacent headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 1},
|
||||
{Min: 2, Max: 2},
|
||||
{Min: 3, Max: 3},
|
||||
{Min: 4, Max: 4},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid zero-based headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 0, Max: 0},
|
||||
{Min: 1, Max: 1},
|
||||
{Min: 2, Max: 2},
|
||||
{Min: 3, Max: 3},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid large value headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 4294967290, Max: 4294967291},
|
||||
{Min: 4294967292, Max: 4294967293},
|
||||
{Min: 4294967294, Max: 4294967294},
|
||||
{Min: 4294967295, Max: 4294967295},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "too few headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 11, Max: 20},
|
||||
{Min: 21, Max: 30},
|
||||
},
|
||||
errorMsg: "all header types should be included:",
|
||||
},
|
||||
{
|
||||
name: "too many headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 11, Max: 20},
|
||||
{Min: 21, Max: 30},
|
||||
{Min: 31, Max: 40},
|
||||
{Min: 41, Max: 50},
|
||||
},
|
||||
errorMsg: "all header types should be included:",
|
||||
},
|
||||
{
|
||||
name: "empty headers",
|
||||
magicHeaders: []MagicHeader{},
|
||||
errorMsg: "all header types should be included:",
|
||||
},
|
||||
{
|
||||
name: "overlapping headers",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 15},
|
||||
{Min: 10, Max: 20},
|
||||
{Min: 25, Max: 30},
|
||||
{Min: 35, Max: 40},
|
||||
},
|
||||
errorMsg: "magic headers shouldn't overlap;",
|
||||
},
|
||||
{
|
||||
name: "overlapping headers at limit-first",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 10, Max: 20},
|
||||
{Min: 25, Max: 30},
|
||||
{Min: 35, Max: 40},
|
||||
},
|
||||
errorMsg: "magic headers shouldn't overlap;",
|
||||
},
|
||||
{
|
||||
name: "overlapping headers at limit-second",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 15, Max: 25},
|
||||
{Min: 25, Max: 30},
|
||||
{Min: 35, Max: 40},
|
||||
},
|
||||
errorMsg: "magic headers shouldn't overlap;",
|
||||
},
|
||||
{
|
||||
name: "overlapping headers at limit-third",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 15, Max: 25},
|
||||
{Min: 30, Max: 35},
|
||||
{Min: 35, Max: 40},
|
||||
},
|
||||
errorMsg: "magic headers shouldn't overlap;",
|
||||
},
|
||||
{
|
||||
name: "identical ranges",
|
||||
magicHeaders: []MagicHeader{
|
||||
{Min: 10, Max: 20},
|
||||
{Min: 10, Max: 20},
|
||||
{Min: 25, Max: 30},
|
||||
{Min: 35, Max: 40},
|
||||
},
|
||||
errorMsg: "magic headers shouldn't overlap;",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result, err := NewMagicHeaders(tt.magicHeaders)
|
||||
|
||||
if tt.errorMsg != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMsg)
|
||||
require.Equal(t, MagicHeaders{}, result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.magicHeaders, result.Values)
|
||||
require.NotNil(t, result.randomGenerator)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Mock PRNG for testing
|
||||
type mockPRNG struct {
|
||||
returnValue uint32
|
||||
}
|
||||
|
||||
func (m *mockPRNG) RandomSizeInRange(min, max uint32) uint32 {
|
||||
return m.returnValue
|
||||
}
|
||||
|
||||
func (m *mockPRNG) Get() uint64 {
|
||||
return 0
|
||||
}
|
||||
func (m *mockPRNG) ReadSize(size int) []byte {
|
||||
return make([]byte, size)
|
||||
}
|
||||
|
||||
func TestMagicHeaders_Get(t *testing.T) {
|
||||
// Create test headers
|
||||
headers := []MagicHeader{
|
||||
{Min: 1, Max: 10},
|
||||
{Min: 11, Max: 20},
|
||||
{Min: 21, Max: 30},
|
||||
{Min: 31, Max: 40},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
defaultMsgType uint32
|
||||
mockValue uint32
|
||||
expectedValue uint32
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid type 1",
|
||||
defaultMsgType: 1,
|
||||
mockValue: 5,
|
||||
expectedValue: 5,
|
||||
},
|
||||
{
|
||||
name: "valid type 2",
|
||||
defaultMsgType: 2,
|
||||
mockValue: 15,
|
||||
expectedValue: 15,
|
||||
},
|
||||
{
|
||||
name: "valid type 3",
|
||||
defaultMsgType: 3,
|
||||
mockValue: 25,
|
||||
expectedValue: 25,
|
||||
},
|
||||
{
|
||||
name: "valid type 4",
|
||||
defaultMsgType: 4,
|
||||
mockValue: 35,
|
||||
expectedValue: 35,
|
||||
},
|
||||
{
|
||||
name: "invalid type 0",
|
||||
defaultMsgType: 0,
|
||||
mockValue: 0,
|
||||
expectedValue: 0,
|
||||
errorMsg: "invalid msg type: 0",
|
||||
},
|
||||
{
|
||||
name: "invalid type 5",
|
||||
defaultMsgType: 5,
|
||||
mockValue: 0,
|
||||
expectedValue: 0,
|
||||
errorMsg: "invalid msg type: 5",
|
||||
},
|
||||
{
|
||||
name: "invalid type max uint32",
|
||||
defaultMsgType: 4294967295,
|
||||
mockValue: 0,
|
||||
expectedValue: 0,
|
||||
errorMsg: "invalid msg type: 4294967295",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Create a new instance with mock PRNG for each test
|
||||
testMagicHeaders := MagicHeaders{
|
||||
Values: headers,
|
||||
randomGenerator: &mockPRNG{returnValue: tt.mockValue},
|
||||
}
|
||||
|
||||
result, err := testMagicHeaders.Get(tt.defaultMsgType)
|
||||
|
||||
if tt.errorMsg != "" {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMsg)
|
||||
require.Equal(t, uint32(0), result)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.expectedValue, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
v2 "math/rand/v2"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type RandomNumberGenerator[T constraints.Integer] interface {
|
||||
RandomSizeInRange(min, max T) T
|
||||
Get() uint64
|
||||
ReadSize(size int) []byte
|
||||
}
|
||||
|
||||
type PRNG[T constraints.Integer] struct {
|
||||
cha8Rand *v2.ChaCha8
|
||||
}
|
||||
|
||||
func NewPRNG[T constraints.Integer]() PRNG[T] {
|
||||
buf := make([]byte, 32)
|
||||
_, _ = crand.Read(buf)
|
||||
|
||||
return PRNG[T]{
|
||||
cha8Rand: v2.NewChaCha8([32]byte(buf)),
|
||||
}
|
||||
}
|
||||
|
||||
func (p PRNG[T]) RandomSizeInRange(min, max T) T {
|
||||
if min > max {
|
||||
panic("min must be less than max")
|
||||
}
|
||||
|
||||
if min == max {
|
||||
return min
|
||||
}
|
||||
|
||||
return T(p.Get()%uint64(max-min)) + min
|
||||
}
|
||||
|
||||
func (p PRNG[T]) Get() uint64 {
|
||||
return p.cha8Rand.Uint64()
|
||||
}
|
||||
|
||||
func (p PRNG[T]) ReadSize(size int) []byte {
|
||||
// TODO: use a memory pool to allocate
|
||||
buf := make([]byte, size)
|
||||
_, _ = p.cha8Rand.Read(buf)
|
||||
return buf
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"github.com/tevino/abool"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// TODO: atomic?/ and better way to use this
|
||||
var PacketCounter *atomic.Uint64 = atomic.NewUint64(0)
|
||||
|
||||
// TODO
|
||||
var WaitResponse = struct {
|
||||
Channel chan struct{}
|
||||
ShouldWait *abool.AtomicBool
|
||||
}{
|
||||
make(chan struct{}, 1),
|
||||
abool.New(),
|
||||
}
|
||||
|
||||
type SpecialHandshakeHandler struct {
|
||||
SpecialJunk TagJunkPacketGenerators
|
||||
|
||||
IsSet bool
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) Validate() error {
|
||||
return handler.SpecialJunk.Validate()
|
||||
}
|
||||
|
||||
func (handler *SpecialHandshakeHandler) GenerateSpecialJunk() [][]byte {
|
||||
if !handler.SpecialJunk.IsDefined() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return handler.SpecialJunk.GeneratePackets()
|
||||
}
|
||||
@@ -1,229 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v2 "math/rand/v2"
|
||||
// "go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type Generator interface {
|
||||
Generate() []byte
|
||||
Size() int
|
||||
}
|
||||
|
||||
type newGenerator func(string) (Generator, error)
|
||||
|
||||
type BytesGenerator struct {
|
||||
value []byte
|
||||
size int
|
||||
}
|
||||
|
||||
func (bg *BytesGenerator) Generate() []byte {
|
||||
return bg.value
|
||||
}
|
||||
|
||||
func (bg *BytesGenerator) Size() int {
|
||||
return bg.size
|
||||
}
|
||||
|
||||
func newBytesGenerator(param string) (Generator, error) {
|
||||
hasPrefix := strings.HasPrefix(param, "0x") || strings.HasPrefix(param, "0X")
|
||||
if !hasPrefix {
|
||||
return nil, fmt.Errorf("not correct hex: %s", param)
|
||||
}
|
||||
|
||||
hex, err := hexToBytes(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hexToBytes: %w", err)
|
||||
}
|
||||
|
||||
return &BytesGenerator{value: hex, size: len(hex)}, nil
|
||||
}
|
||||
|
||||
func hexToBytes(hexStr string) ([]byte, error) {
|
||||
hexStr = strings.TrimPrefix(hexStr, "0x")
|
||||
hexStr = strings.TrimPrefix(hexStr, "0X")
|
||||
|
||||
// Ensure even length (pad with leading zero if needed)
|
||||
if len(hexStr)%2 != 0 {
|
||||
hexStr = "0" + hexStr
|
||||
}
|
||||
|
||||
return hex.DecodeString(hexStr)
|
||||
}
|
||||
|
||||
type randomGeneratorBase struct {
|
||||
cha8Rand *v2.ChaCha8
|
||||
size int
|
||||
}
|
||||
|
||||
func newRandomGeneratorBase(param string) (*randomGeneratorBase, error) {
|
||||
size, err := strconv.Atoi(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse int: %w", err)
|
||||
}
|
||||
|
||||
if size > 1000 {
|
||||
return nil, fmt.Errorf("size must be less than 1000")
|
||||
}
|
||||
|
||||
buf := make([]byte, 32)
|
||||
_, err = crand.Read(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crand read: %w", err)
|
||||
}
|
||||
|
||||
return &randomGeneratorBase{
|
||||
cha8Rand: v2.NewChaCha8([32]byte(buf)),
|
||||
size: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rpg *randomGeneratorBase) generate() []byte {
|
||||
junk := make([]byte, rpg.size)
|
||||
rpg.cha8Rand.Read(junk)
|
||||
return junk
|
||||
}
|
||||
|
||||
func (rpg *randomGeneratorBase) Size() int {
|
||||
return rpg.size
|
||||
}
|
||||
|
||||
type RandomBytesGenerator struct {
|
||||
*randomGeneratorBase
|
||||
}
|
||||
|
||||
func newRandomBytesGenerator(param string) (Generator, error) {
|
||||
rpgBase, err := newRandomGeneratorBase(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new random bytes generator: %w", err)
|
||||
}
|
||||
|
||||
return &RandomBytesGenerator{randomGeneratorBase: rpgBase}, nil
|
||||
}
|
||||
|
||||
func (rpg *RandomBytesGenerator) Generate() []byte {
|
||||
return rpg.generate()
|
||||
}
|
||||
|
||||
const alphanumericChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
type RandomASCIIGenerator struct {
|
||||
*randomGeneratorBase
|
||||
}
|
||||
|
||||
func newRandomASCIIGenerator(param string) (Generator, error) {
|
||||
rpgBase, err := newRandomGeneratorBase(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new random ascii generator: %w", err)
|
||||
}
|
||||
|
||||
return &RandomASCIIGenerator{randomGeneratorBase: rpgBase}, nil
|
||||
}
|
||||
|
||||
func (rpg *RandomASCIIGenerator) Generate() []byte {
|
||||
junk := rpg.generate()
|
||||
|
||||
result := make([]byte, rpg.size)
|
||||
for i, b := range junk {
|
||||
result[i] = alphanumericChars[b%byte(len(alphanumericChars))]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type RandomDigitGenerator struct {
|
||||
*randomGeneratorBase
|
||||
}
|
||||
|
||||
func newRandomDigitGenerator(param string) (Generator, error) {
|
||||
rpgBase, err := newRandomGeneratorBase(param)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new random digit generator: %w", err)
|
||||
}
|
||||
|
||||
return &RandomDigitGenerator{randomGeneratorBase: rpgBase}, nil
|
||||
}
|
||||
|
||||
func (rpg *RandomDigitGenerator) Generate() []byte {
|
||||
junk := rpg.generate()
|
||||
|
||||
result := make([]byte, rpg.size)
|
||||
for i, b := range junk {
|
||||
result[i] = '0' + (b % 10) // Convert to digit character
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type TimestampGenerator struct {
|
||||
}
|
||||
|
||||
func (tg *TimestampGenerator) Generate() []byte {
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, uint64(time.Now().Unix()))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (tg *TimestampGenerator) Size() int {
|
||||
return 8
|
||||
}
|
||||
|
||||
func newTimestampGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("timestamp param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &TimestampGenerator{}, nil
|
||||
}
|
||||
|
||||
type PacketCounterGenerator struct {
|
||||
}
|
||||
|
||||
func (c *PacketCounterGenerator) Generate() []byte {
|
||||
buf := make([]byte, 8)
|
||||
// TODO: better way to handle counter tag
|
||||
binary.BigEndian.PutUint64(buf, PacketCounter.Load())
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *PacketCounterGenerator) Size() int {
|
||||
return 8
|
||||
}
|
||||
|
||||
func newPacketCounterGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("packet counter param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &PacketCounterGenerator{}, nil
|
||||
}
|
||||
|
||||
type WaitResponseGenerator struct {
|
||||
}
|
||||
|
||||
func (c *WaitResponseGenerator) Generate() []byte {
|
||||
WaitResponse.ShouldWait.Set()
|
||||
<-WaitResponse.Channel
|
||||
WaitResponse.ShouldWait.UnSet()
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
func (c *WaitResponseGenerator) Size() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func newWaitResponseGenerator(param string) (Generator, error) {
|
||||
if len(param) != 0 {
|
||||
return nil, fmt.Errorf("wait response param needs to be empty: %s", param)
|
||||
}
|
||||
|
||||
return &WaitResponseGenerator{}, nil
|
||||
}
|
||||
@@ -1,321 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewBytesGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "wrong start",
|
||||
args: args{
|
||||
param: "123456",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "not only hex value with X",
|
||||
args: args{
|
||||
param: "0X12345q",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "not only hex value with x",
|
||||
args: args{
|
||||
param: "0x12345q",
|
||||
},
|
||||
wantErr: fmt.Errorf("not correct hex"),
|
||||
},
|
||||
{
|
||||
name: "valid hex",
|
||||
args: args{
|
||||
param: "0xf6ab3267fa",
|
||||
},
|
||||
want: []byte{0xf6, 0xab, 0x32, 0x67, 0xfa},
|
||||
},
|
||||
{
|
||||
name: "valid hex with odd length",
|
||||
args: args{
|
||||
param: "0xfab3267fa",
|
||||
},
|
||||
want: []byte{0xf, 0xab, 0x32, 0x67, 0xfa},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := newBytesGenerator(tt.args.param)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
gotValues := got.Generate()
|
||||
require.Equal(t, tt.want, gotValues)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRandomBytesGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "not an int",
|
||||
args: args{
|
||||
param: "x",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "too large",
|
||||
args: args{
|
||||
param: "1001",
|
||||
},
|
||||
wantErr: fmt.Errorf("random packet size must be less than 1000"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
param: "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := newRandomBytesGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
first := got.Generate()
|
||||
|
||||
second := got.Generate()
|
||||
require.NotEqual(t, first, second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRandomASCIIGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "not an int",
|
||||
args: args{
|
||||
param: "x",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "too large",
|
||||
args: args{
|
||||
param: "1001",
|
||||
},
|
||||
wantErr: fmt.Errorf("random packet size must be less than 1000"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
param: "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := newRandomASCIIGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
first := got.Generate()
|
||||
|
||||
second := got.Generate()
|
||||
require.NotEqual(t, first, second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRandomDigitGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type args struct {
|
||||
param string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
args: args{
|
||||
param: "",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "not an int",
|
||||
args: args{
|
||||
param: "x",
|
||||
},
|
||||
wantErr: fmt.Errorf("parse int"),
|
||||
},
|
||||
{
|
||||
name: "too large",
|
||||
args: args{
|
||||
param: "1001",
|
||||
},
|
||||
wantErr: fmt.Errorf("random packet size must be less than 1000"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
param: "12",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, err := newRandomDigitGenerator(tt.args.param)
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
|
||||
require.Nil(t, err)
|
||||
require.NotNil(t, got)
|
||||
first := got.Generate()
|
||||
|
||||
second := got.Generate()
|
||||
require.NotEqual(t, first, second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketCounterGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
param string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid empty param",
|
||||
param: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid non-empty param",
|
||||
param: "anything",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gen, err := newPacketCounterGenerator(tc.param)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 8, gen.Size())
|
||||
|
||||
// Reset counter to known value for test
|
||||
initialCount := uint64(42)
|
||||
PacketCounter.Store(initialCount)
|
||||
|
||||
output := gen.Generate()
|
||||
require.Equal(t, 8, len(output))
|
||||
|
||||
// Verify counter value in output
|
||||
counterValue := binary.BigEndian.Uint64(output)
|
||||
require.Equal(t, initialCount, counterValue)
|
||||
|
||||
// Increment counter and verify change
|
||||
PacketCounter.Add(1)
|
||||
output = gen.Generate()
|
||||
counterValue = binary.BigEndian.Uint64(output)
|
||||
require.Equal(t, initialCount+1, counterValue)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type TagJunkPacketGenerator struct {
|
||||
name string
|
||||
tagValue string
|
||||
|
||||
packetSize int
|
||||
generators []Generator
|
||||
}
|
||||
|
||||
func newTagJunkPacketGenerator(name, tagValue string, size int) TagJunkPacketGenerator {
|
||||
return TagJunkPacketGenerator{
|
||||
name: name,
|
||||
tagValue: tagValue,
|
||||
generators: make([]Generator, 0, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) append(generator Generator) {
|
||||
tg.generators = append(tg.generators, generator)
|
||||
tg.packetSize += generator.Size()
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) generatePacket() []byte {
|
||||
packet := make([]byte, 0, tg.packetSize)
|
||||
for _, generator := range tg.generators {
|
||||
packet = append(packet, generator.Generate()...)
|
||||
}
|
||||
|
||||
return packet
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) Name() string {
|
||||
return tg.name
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) nameIndex() (int, error) {
|
||||
if len(tg.name) != 2 {
|
||||
return 0, fmt.Errorf("name must be 2 character long: %s", tg.name)
|
||||
}
|
||||
|
||||
index, err := strconv.Atoi(tg.name[1:2])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("name 2 char should be an int %w", err)
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
|
||||
return IpcFields{
|
||||
Key: tg.name,
|
||||
Value: tg.tagValue,
|
||||
}
|
||||
}
|
||||
@@ -1,210 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewTagJunkGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
genName string
|
||||
size int
|
||||
expected TagJunkPacketGenerator
|
||||
}{
|
||||
{
|
||||
name: "Create new generator with empty name",
|
||||
genName: "",
|
||||
size: 0,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 0),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create new generator with valid name",
|
||||
genName: "T1",
|
||||
size: 0,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "T1",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 0),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Create new generator with non-zero size",
|
||||
genName: "T2",
|
||||
size: 5,
|
||||
expected: TagJunkPacketGenerator{
|
||||
name: "T2",
|
||||
packetSize: 0,
|
||||
generators: make([]Generator, 5),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
result := newTagJunkPacketGenerator(tc.genName, "", tc.size)
|
||||
require.Equal(t, tc.expected.name, result.name)
|
||||
require.Equal(t, tc.expected.packetSize, result.packetSize)
|
||||
require.Equal(t, cap(result.generators), len(tc.expected.generators))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorAppend(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
initialState TagJunkPacketGenerator
|
||||
mockSize int
|
||||
expectedLength int
|
||||
expectedSize int
|
||||
}{
|
||||
{
|
||||
name: "Append to empty generator",
|
||||
initialState: newTagJunkPacketGenerator("T1", "", 0),
|
||||
mockSize: 5,
|
||||
expectedLength: 1,
|
||||
expectedSize: 5,
|
||||
},
|
||||
{
|
||||
name: "Append to non-empty generator",
|
||||
initialState: TagJunkPacketGenerator{
|
||||
name: "T2",
|
||||
packetSize: 10,
|
||||
generators: make([]Generator, 2),
|
||||
},
|
||||
mockSize: 7,
|
||||
expectedLength: 3, // 2 existing + 1 new
|
||||
expectedSize: 17, // 10 + 7
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := tc.initialState
|
||||
mockGen := internal.NewMockGenerator(tc.mockSize)
|
||||
|
||||
tg.append(mockGen)
|
||||
|
||||
require.Equal(t, tc.expectedLength, len(tg.generators))
|
||||
require.Equal(t, tc.expectedSize, tg.packetSize)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorGenerate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create mock generators for testing
|
||||
mockGen1 := internal.NewMockByteGenerator([]byte{0x01, 0x02})
|
||||
mockGen2 := internal.NewMockByteGenerator([]byte{0x03, 0x04, 0x05})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupGenerator func() TagJunkPacketGenerator
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
name: "Generate with empty generators",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
return newTagJunkPacketGenerator("T1", "", 0)
|
||||
},
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "Generate with single generator",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
tg := newTagJunkPacketGenerator("T2", "", 0)
|
||||
tg.append(mockGen1)
|
||||
return tg
|
||||
},
|
||||
expected: []byte{0x01, 0x02},
|
||||
},
|
||||
{
|
||||
name: "Generate with multiple generators",
|
||||
setupGenerator: func() TagJunkPacketGenerator {
|
||||
tg := newTagJunkPacketGenerator("T3", "", 0)
|
||||
tg.append(mockGen1)
|
||||
tg.append(mockGen2)
|
||||
return tg
|
||||
},
|
||||
expected: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := tc.setupGenerator()
|
||||
result := tg.generatePacket()
|
||||
|
||||
require.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorNameIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
generatorName string
|
||||
expectedIndex int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid name with digit",
|
||||
generatorName: "T5",
|
||||
expectedIndex: 5,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - too short",
|
||||
generatorName: "T",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - too long",
|
||||
generatorName: "T55",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid name - non-digit second character",
|
||||
generatorName: "TX",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tg := TagJunkPacketGenerator{name: tc.generatorName}
|
||||
index, err := tg.nameIndex()
|
||||
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.expectedIndex, index)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package awg
|
||||
|
||||
import "fmt"
|
||||
|
||||
type TagJunkPacketGenerators struct {
|
||||
tagGenerators []TagJunkPacketGenerator
|
||||
length int
|
||||
DefaultJunkCount int // Jc
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) AppendGenerator(
|
||||
generator TagJunkPacketGenerator,
|
||||
) {
|
||||
generators.tagGenerators = append(generators.tagGenerators, generator)
|
||||
generators.length++
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) IsDefined() bool {
|
||||
return len(generators.tagGenerators) > 0
|
||||
}
|
||||
|
||||
// validate that packets were defined consecutively
|
||||
func (generators *TagJunkPacketGenerators) Validate() error {
|
||||
seen := make([]bool, len(generators.tagGenerators))
|
||||
for _, generator := range generators.tagGenerators {
|
||||
index, err := generator.nameIndex()
|
||||
if index > len(generators.tagGenerators) {
|
||||
return fmt.Errorf("junk packet index should be consecutive")
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("name index: %w", err)
|
||||
} else {
|
||||
seen[index-1] = true
|
||||
}
|
||||
}
|
||||
|
||||
for _, found := range seen {
|
||||
if !found {
|
||||
return fmt.Errorf("junk packet index should be consecutive")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (generators *TagJunkPacketGenerators) GeneratePackets() [][]byte {
|
||||
var rv = make([][]byte, 0, generators.length+generators.DefaultJunkCount)
|
||||
|
||||
for i, tagGenerator := range generators.tagGenerators {
|
||||
rv = append(rv, make([]byte, tagGenerator.packetSize))
|
||||
copy(rv[i], tagGenerator.generatePacket())
|
||||
PacketCounter.Inc()
|
||||
}
|
||||
PacketCounter.Add(uint64(generators.DefaultJunkCount))
|
||||
|
||||
return rv
|
||||
}
|
||||
|
||||
func (tg *TagJunkPacketGenerators) IpcGetFields() []IpcFields {
|
||||
rv := make([]IpcFields, 0, len(tg.tagGenerators))
|
||||
for _, generator := range tg.tagGenerators {
|
||||
rv = append(rv, generator.IpcGetFields())
|
||||
}
|
||||
|
||||
return rv
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg/internal"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTagJunkGeneratorHandlerAppendGenerator(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
generator TagJunkPacketGenerator
|
||||
}{
|
||||
{
|
||||
name: "append single generator",
|
||||
generator: newTagJunkPacketGenerator("t1", "", 10),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
|
||||
// Initial length should be 0
|
||||
require.Equal(t, 0, generators.length)
|
||||
require.Empty(t, generators.tagGenerators)
|
||||
|
||||
// After append, length should be 1 and generator should be added
|
||||
generators.AppendGenerator(tt.generator)
|
||||
require.Equal(t, 1, generators.length)
|
||||
require.Len(t, generators.tagGenerators, 1)
|
||||
require.Equal(t, tt.generator, generators.tagGenerators[0])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
generators []TagJunkPacketGenerator
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "bad start",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t3", "", 10),
|
||||
newTagJunkPacketGenerator("t4", "", 10),
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "junk packet index should be consecutive",
|
||||
},
|
||||
{
|
||||
name: "non-consecutive indices",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t1", "", 10),
|
||||
newTagJunkPacketGenerator("t3", "", 10), // Missing t2
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "junk packet index should be consecutive",
|
||||
},
|
||||
{
|
||||
name: "consecutive indices",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("t1", "", 10),
|
||||
newTagJunkPacketGenerator("t2", "", 10),
|
||||
newTagJunkPacketGenerator("t3", "", 10),
|
||||
newTagJunkPacketGenerator("t4", "", 10),
|
||||
newTagJunkPacketGenerator("t5", "", 10),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nameIndex error",
|
||||
generators: []TagJunkPacketGenerator{
|
||||
newTagJunkPacketGenerator("error", "", 10),
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "name must be 2 character long",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
for _, gen := range tt.generators {
|
||||
generators.AppendGenerator(gen)
|
||||
}
|
||||
|
||||
err := generators.Validate()
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errMsg)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagJunkGeneratorHandlerGenerate(t *testing.T) {
|
||||
mockByte1 := []byte{0x01, 0x02}
|
||||
mockByte2 := []byte{0x03, 0x04, 0x05}
|
||||
mockGen1 := internal.NewMockByteGenerator(mockByte1)
|
||||
mockGen2 := internal.NewMockByteGenerator(mockByte2)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupGenerator func() []TagJunkPacketGenerator
|
||||
expected [][]byte
|
||||
}{
|
||||
{
|
||||
name: "generate with no default junk",
|
||||
setupGenerator: func() []TagJunkPacketGenerator {
|
||||
tg1 := newTagJunkPacketGenerator("t1", "", 0)
|
||||
tg1.append(mockGen1)
|
||||
tg1.append(mockGen2)
|
||||
tg2 := newTagJunkPacketGenerator("t2", "", 0)
|
||||
tg2.append(mockGen2)
|
||||
tg2.append(mockGen1)
|
||||
|
||||
return []TagJunkPacketGenerator{tg1, tg2}
|
||||
},
|
||||
expected: [][]byte{
|
||||
append(mockByte1, mockByte2...),
|
||||
append(mockByte2, mockByte1...),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
generators := &TagJunkPacketGenerators{}
|
||||
tagGenerators := tt.setupGenerator()
|
||||
for _, gen := range tagGenerators {
|
||||
generators.AppendGenerator(gen)
|
||||
}
|
||||
|
||||
result := generators.GeneratePackets()
|
||||
require.Equal(t, result, tt.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type IpcFields struct{ Key, Value string }
|
||||
|
||||
type EnumTag string
|
||||
|
||||
const (
|
||||
BytesEnumTag EnumTag = "b"
|
||||
CounterEnumTag EnumTag = "c"
|
||||
TimestampEnumTag EnumTag = "t"
|
||||
RandomBytesEnumTag EnumTag = "r"
|
||||
RandomASCIIEnumTag EnumTag = "rc"
|
||||
RandomDigitEnumTag EnumTag = "rd"
|
||||
)
|
||||
|
||||
var generatorCreator = map[EnumTag]newGenerator{
|
||||
BytesEnumTag: newBytesGenerator,
|
||||
CounterEnumTag: newPacketCounterGenerator,
|
||||
TimestampEnumTag: newTimestampGenerator,
|
||||
RandomBytesEnumTag: newRandomBytesGenerator,
|
||||
RandomASCIIEnumTag: newRandomASCIIGenerator,
|
||||
RandomDigitEnumTag: newRandomDigitGenerator,
|
||||
}
|
||||
|
||||
// helper map to determine enumTags are unique
|
||||
var uniqueTags = map[EnumTag]bool{
|
||||
CounterEnumTag: false,
|
||||
TimestampEnumTag: false,
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
Name EnumTag
|
||||
Param string
|
||||
}
|
||||
|
||||
func parseTag(input string) (Tag, error) {
|
||||
// Regular expression to match <tagname optional_param>
|
||||
re := regexp.MustCompile(`([a-zA-Z]+)(?:\s+([^>]+))?>`)
|
||||
|
||||
match := re.FindStringSubmatch(input)
|
||||
tag := Tag{
|
||||
Name: EnumTag(match[1]),
|
||||
}
|
||||
if len(match) > 2 && match[2] != "" {
|
||||
tag.Param = strings.TrimSpace(match[2])
|
||||
}
|
||||
|
||||
return tag, nil
|
||||
}
|
||||
|
||||
func ParseTagJunkGenerator(name, input string) (TagJunkPacketGenerator, error) {
|
||||
inputSlice := strings.Split(input, "<")
|
||||
if len(inputSlice) <= 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("empty input: %s", input)
|
||||
}
|
||||
|
||||
uniqueTagCheck := make(map[EnumTag]bool, len(uniqueTags))
|
||||
maps.Copy(uniqueTagCheck, uniqueTags)
|
||||
|
||||
// skip byproduct of split
|
||||
inputSlice = inputSlice[1:]
|
||||
rv := newTagJunkPacketGenerator(name, input, len(inputSlice))
|
||||
for _, inputParam := range inputSlice {
|
||||
if len(inputParam) <= 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||
"empty tag in input: %s",
|
||||
inputSlice,
|
||||
)
|
||||
} else if strings.Count(inputParam, ">") != 1 {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("ill formated input: %s", input)
|
||||
}
|
||||
|
||||
tag, _ := parseTag(inputParam)
|
||||
creator, ok := generatorCreator[tag.Name]
|
||||
if !ok {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("invalid tag: %s", tag.Name)
|
||||
}
|
||||
if present, ok := uniqueTagCheck[tag.Name]; ok {
|
||||
if present {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf(
|
||||
"tag %s needs to be unique",
|
||||
tag.Name,
|
||||
)
|
||||
}
|
||||
uniqueTagCheck[tag.Name] = true
|
||||
}
|
||||
generator, err := creator(tag.Param)
|
||||
if err != nil {
|
||||
return TagJunkPacketGenerator{}, fmt.Errorf("gen: %w", err)
|
||||
}
|
||||
|
||||
// TODO: handle counter tag
|
||||
// if tag.Name == CounterEnumTag {
|
||||
// packetCounter, ok := generator.(*PacketCounterGenerator)
|
||||
// if !ok {
|
||||
// log.Fatalf("packet counter generator expected, got %T", generator)
|
||||
// }
|
||||
// PacketCounter = packetCounter.counter
|
||||
// }
|
||||
|
||||
rv.append(generator)
|
||||
}
|
||||
|
||||
return rv, nil
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
package awg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
type args struct {
|
||||
name string
|
||||
input string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid name",
|
||||
args: args{name: "apple", input: ""},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
args: args{name: "i1", input: ""},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "extra >",
|
||||
args: args{name: "i1", input: "<b 0xf6ab3267fa><c>>"},
|
||||
wantErr: fmt.Errorf("ill formated input"),
|
||||
},
|
||||
{
|
||||
name: "extra <",
|
||||
args: args{name: "i1", input: "<<b 0xf6ab3267fa><c>"},
|
||||
wantErr: fmt.Errorf("empty tag in input"),
|
||||
},
|
||||
{
|
||||
name: "empty <>",
|
||||
args: args{name: "i1", input: "<><b 0xf6ab3267fa><c>"},
|
||||
wantErr: fmt.Errorf("empty tag in input"),
|
||||
},
|
||||
{
|
||||
name: "invalid tag",
|
||||
args: args{name: "i1", input: "<q 0xf6ab3267fa>"},
|
||||
wantErr: fmt.Errorf("invalid tag"),
|
||||
},
|
||||
{
|
||||
name: "counter uniqueness violation",
|
||||
args: args{name: "i1", input: "<c><c>"},
|
||||
wantErr: fmt.Errorf("parse tag needs to be unique"),
|
||||
},
|
||||
{
|
||||
name: "timestamp uniqueness violation",
|
||||
args: args{name: "i1", input: "<t><t>"},
|
||||
wantErr: fmt.Errorf("parse tag needs to be unique"),
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{input: "<b 0xf6ab3267fa><c><b 0xf6ab><t><r 10><wt 10>"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := ParseTagJunkGenerator(tt.args.name, tt.args.input)
|
||||
|
||||
// TODO: ErrorAs doesn't work as you think
|
||||
if tt.wantErr != nil {
|
||||
require.ErrorAs(t, err, &tt.wantErr)
|
||||
return
|
||||
}
|
||||
require.Nil(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func TestCookieMAC1(t *testing.T) {
|
||||
0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d,
|
||||
}
|
||||
generator.AddMacs(msg)
|
||||
reply, err := checker.CreateReply(msg, 1377, src, DefaultMessageCookieReplyType)
|
||||
reply, err := checker.CreateReply(msg, 1377, src, MessageCookieReplyType)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to create cookie reply:", err)
|
||||
}
|
||||
|
||||
425
device/device.go
425
device/device.go
@@ -6,57 +6,17 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
|
||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
||||
)
|
||||
|
||||
type Version uint8
|
||||
|
||||
const (
|
||||
VersionDefault Version = iota
|
||||
VersionAwg
|
||||
VersionAwgSpecialHandshake
|
||||
)
|
||||
|
||||
// TODO:
|
||||
type AtomicVersion struct {
|
||||
value atomic.Uint32
|
||||
}
|
||||
|
||||
func NewAtomicVersion(v Version) *AtomicVersion {
|
||||
av := &AtomicVersion{}
|
||||
av.Store(v)
|
||||
return av
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Load() Version {
|
||||
return Version(av.value.Load())
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Store(v Version) {
|
||||
av.value.Store(uint32(v))
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) CompareAndSwap(old, new Version) bool {
|
||||
return av.value.CompareAndSwap(uint32(old), uint32(new))
|
||||
}
|
||||
|
||||
func (av *AtomicVersion) Swap(new Version) Version {
|
||||
return Version(av.value.Swap(uint32(new)))
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
state struct {
|
||||
// state holds the device's state. It is accessed atomically.
|
||||
@@ -130,8 +90,27 @@ type Device struct {
|
||||
closed chan struct{}
|
||||
log *Logger
|
||||
|
||||
version Version
|
||||
awg awg.Protocol
|
||||
junk struct {
|
||||
min int
|
||||
max int
|
||||
count int
|
||||
}
|
||||
|
||||
headers struct {
|
||||
init *magicHeader
|
||||
cookie *magicHeader
|
||||
response *magicHeader
|
||||
transport *magicHeader
|
||||
}
|
||||
|
||||
paddings struct {
|
||||
init int
|
||||
response int
|
||||
cookie int
|
||||
transport int
|
||||
}
|
||||
|
||||
ipackets [5]*obfChain
|
||||
}
|
||||
|
||||
// deviceState represents the state of a Device.
|
||||
@@ -342,6 +321,11 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||
device.rate.limiter.Init()
|
||||
device.indexTable.Init()
|
||||
|
||||
device.headers.init = &magicHeader{start: MessageInitiationType, end: MessageInitiationType}
|
||||
device.headers.response = &magicHeader{start: MessageResponseType, end: MessageResponseType}
|
||||
device.headers.cookie = &magicHeader{start: MessageCookieReplyType, end: MessageCookieReplyType}
|
||||
device.headers.transport = &magicHeader{start: MessageTransportType, end: MessageTransportType}
|
||||
|
||||
device.PopulatePools()
|
||||
|
||||
// create queues
|
||||
@@ -439,8 +423,6 @@ func (device *Device) Close() {
|
||||
|
||||
device.rate.limiter.Close()
|
||||
|
||||
device.resetProtocol()
|
||||
|
||||
device.log.Verbosef("Device closed")
|
||||
close(device.closed)
|
||||
}
|
||||
@@ -580,358 +562,3 @@ func (device *Device) BindClose() error {
|
||||
device.net.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (device *Device) isAWG() bool {
|
||||
return device.version >= VersionAwg
|
||||
}
|
||||
|
||||
func (device *Device) resetProtocol() {
|
||||
// restore default message type values
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
}
|
||||
|
||||
func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
|
||||
if !tempAwg.Cfg.IsSet && !tempAwg.HandshakeHandler.IsSet {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
|
||||
isAwgOn := false
|
||||
device.awg.Mux.Lock()
|
||||
if tempAwg.Cfg.JunkPacketCount < 0 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketCount should be non negative",
|
||||
),
|
||||
)
|
||||
}
|
||||
device.awg.Cfg.JunkPacketCount = tempAwg.Cfg.JunkPacketCount
|
||||
if tempAwg.Cfg.JunkPacketCount != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
device.awg.Cfg.JunkPacketMinSize = tempAwg.Cfg.JunkPacketMinSize
|
||||
if tempAwg.Cfg.JunkPacketMinSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
if device.awg.Cfg.JunkPacketCount > 0 &&
|
||||
tempAwg.Cfg.JunkPacketMaxSize == tempAwg.Cfg.JunkPacketMinSize {
|
||||
|
||||
tempAwg.Cfg.JunkPacketMaxSize++ // to make rand gen work
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.JunkPacketMaxSize >= MaxSegmentSize {
|
||||
device.awg.Cfg.JunkPacketMinSize = 0
|
||||
device.awg.Cfg.JunkPacketMaxSize = 1
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
|
||||
tempAwg.Cfg.JunkPacketMaxSize,
|
||||
MaxSegmentSize,
|
||||
))
|
||||
} else if tempAwg.Cfg.JunkPacketMaxSize < tempAwg.Cfg.JunkPacketMinSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"maxSize: %d; should be greater than minSize: %d",
|
||||
tempAwg.Cfg.JunkPacketMaxSize,
|
||||
tempAwg.Cfg.JunkPacketMinSize,
|
||||
))
|
||||
} else {
|
||||
device.awg.Cfg.JunkPacketMaxSize = tempAwg.Cfg.JunkPacketMaxSize
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.JunkPacketMaxSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
magicHeaders := make([]awg.MagicHeader, 4)
|
||||
|
||||
if len(tempAwg.Cfg.MagicHeaders.Values) != 4 {
|
||||
return ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
"magic headers should have 4 values; got: %d",
|
||||
len(tempAwg.Cfg.MagicHeaders.Values),
|
||||
)
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.MagicHeaders.Values[0].Min > 4 {
|
||||
isAwgOn = true
|
||||
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
|
||||
magicHeaders[0] = tempAwg.Cfg.MagicHeaders.Values[0]
|
||||
|
||||
MessageInitiationType = magicHeaders[0].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default init type")
|
||||
MessageInitiationType = DefaultMessageInitiationType
|
||||
magicHeaders[0] = awg.NewMagicHeaderSameValue(DefaultMessageInitiationType)
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.MagicHeaders.Values[1].Min > 4 {
|
||||
isAwgOn = true
|
||||
|
||||
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
|
||||
magicHeaders[1] = tempAwg.Cfg.MagicHeaders.Values[1]
|
||||
MessageResponseType = magicHeaders[1].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default response type")
|
||||
MessageResponseType = DefaultMessageResponseType
|
||||
magicHeaders[1] = awg.NewMagicHeaderSameValue(DefaultMessageResponseType)
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.MagicHeaders.Values[2].Min > 4 {
|
||||
isAwgOn = true
|
||||
|
||||
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
|
||||
magicHeaders[2] = tempAwg.Cfg.MagicHeaders.Values[2]
|
||||
MessageCookieReplyType = magicHeaders[2].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default underload type")
|
||||
MessageCookieReplyType = DefaultMessageCookieReplyType
|
||||
magicHeaders[2] = awg.NewMagicHeaderSameValue(DefaultMessageCookieReplyType)
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.MagicHeaders.Values[3].Min > 4 {
|
||||
isAwgOn = true
|
||||
|
||||
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
|
||||
magicHeaders[3] = tempAwg.Cfg.MagicHeaders.Values[3]
|
||||
MessageTransportType = magicHeaders[3].Min
|
||||
} else {
|
||||
device.log.Verbosef("UAPI: Using default transport type")
|
||||
MessageTransportType = DefaultMessageTransportType
|
||||
magicHeaders[3] = awg.NewMagicHeaderSameValue(DefaultMessageTransportType)
|
||||
}
|
||||
|
||||
var err error
|
||||
device.awg.Cfg.MagicHeaders, err = awg.NewMagicHeaders(magicHeaders)
|
||||
if err != nil {
|
||||
errs = append(errs, ipcErrorf(ipc.IpcErrorInvalid, "new magic headers: %w", err))
|
||||
}
|
||||
|
||||
isSameHeaderMap := map[uint32]struct{}{
|
||||
MessageInitiationType: {},
|
||||
MessageResponseType: {},
|
||||
MessageCookieReplyType: {},
|
||||
MessageTransportType: {},
|
||||
}
|
||||
|
||||
// size will be different if same values
|
||||
if len(isSameHeaderMap) != 4 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
|
||||
MessageInitiationType,
|
||||
MessageResponseType,
|
||||
MessageCookieReplyType,
|
||||
MessageTransportType,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
newInitSize := MessageInitiationSize + tempAwg.Cfg.InitHeaderJunkSize
|
||||
|
||||
if newInitSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.Cfg.InitHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.Cfg.InitHeaderJunkSize = tempAwg.Cfg.InitHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.InitHeaderJunkSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
newResponseSize := MessageResponseSize + tempAwg.Cfg.ResponseHeaderJunkSize
|
||||
|
||||
if newResponseSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.Cfg.ResponseHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.Cfg.ResponseHeaderJunkSize = tempAwg.Cfg.ResponseHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.ResponseHeaderJunkSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
newCookieSize := MessageCookieReplySize + tempAwg.Cfg.CookieReplyHeaderJunkSize
|
||||
|
||||
if newCookieSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`cookie reply size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.Cfg.CookieReplyHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.Cfg.CookieReplyHeaderJunkSize = tempAwg.Cfg.CookieReplyHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.CookieReplyHeaderJunkSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
newTransportSize := MessageTransportSize + tempAwg.Cfg.TransportHeaderJunkSize
|
||||
|
||||
if newTransportSize >= MaxSegmentSize {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`transport size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
||||
tempAwg.Cfg.TransportHeaderJunkSize,
|
||||
MaxSegmentSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
device.awg.Cfg.TransportHeaderJunkSize = tempAwg.Cfg.TransportHeaderJunkSize
|
||||
}
|
||||
|
||||
if tempAwg.Cfg.TransportHeaderJunkSize != 0 {
|
||||
isAwgOn = true
|
||||
}
|
||||
|
||||
isSameSizeMap := map[int]struct{}{
|
||||
newInitSize: {},
|
||||
newResponseSize: {},
|
||||
newCookieSize: {},
|
||||
newTransportSize: {},
|
||||
}
|
||||
|
||||
if len(isSameSizeMap) != 4 {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid,
|
||||
`new sizes should differ; init: %d; response: %d; cookie: %d; trans: %d`,
|
||||
newInitSize,
|
||||
newResponseSize,
|
||||
newCookieSize,
|
||||
newTransportSize,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
msgTypeToJunkSize = map[uint32]int{
|
||||
MessageInitiationType: device.awg.Cfg.InitHeaderJunkSize,
|
||||
MessageResponseType: device.awg.Cfg.ResponseHeaderJunkSize,
|
||||
MessageCookieReplyType: device.awg.Cfg.CookieReplyHeaderJunkSize,
|
||||
MessageTransportType: device.awg.Cfg.TransportHeaderJunkSize,
|
||||
}
|
||||
|
||||
packetSizeToMsgType = map[int]uint32{
|
||||
newInitSize: MessageInitiationType,
|
||||
newResponseSize: MessageResponseType,
|
||||
newCookieSize: MessageCookieReplyType,
|
||||
newTransportSize: MessageTransportType,
|
||||
}
|
||||
}
|
||||
|
||||
device.awg.IsOn.SetTo(isAwgOn)
|
||||
device.awg.JunkCreator = awg.NewJunkCreator(device.awg.Cfg)
|
||||
|
||||
if tempAwg.HandshakeHandler.IsSet {
|
||||
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
|
||||
errs = append(errs, ipcErrorf(
|
||||
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
|
||||
} else {
|
||||
device.awg.HandshakeHandler = tempAwg.HandshakeHandler
|
||||
device.awg.HandshakeHandler.SpecialJunk.DefaultJunkCount = tempAwg.Cfg.JunkPacketCount
|
||||
device.version = VersionAwgSpecialHandshake
|
||||
}
|
||||
} else {
|
||||
device.version = VersionAwg
|
||||
}
|
||||
|
||||
device.awg.Mux.Unlock()
|
||||
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (device *Device) ProcessAWGPacket(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
|
||||
// TODO:
|
||||
// if awg.WaitResponse.ShouldWait.IsSet() {
|
||||
// awg.WaitResponse.Channel <- struct{}{}
|
||||
// }
|
||||
|
||||
expectedMsgType, isKnownSize := packetSizeToMsgType[size]
|
||||
if !isKnownSize {
|
||||
msgType, err := device.handleTransport(size, packet, buffer)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("handle transport: %w", err)
|
||||
}
|
||||
|
||||
return msgType, nil
|
||||
}
|
||||
|
||||
junkSize := msgTypeToJunkSize[expectedMsgType]
|
||||
|
||||
// transport size can align with other header types;
|
||||
// making sure we have the right actualMsgType
|
||||
actualMsgType, err := device.getMsgType(packet, junkSize)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get msg type: %w", err)
|
||||
}
|
||||
|
||||
if actualMsgType == expectedMsgType {
|
||||
*packet = (*packet)[junkSize:]
|
||||
return actualMsgType, nil
|
||||
}
|
||||
|
||||
device.log.Verbosef("awg: transport packet lined up with another msg type")
|
||||
|
||||
msgType, err := device.handleTransport(size, packet, buffer)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("handle transport: %w", err)
|
||||
}
|
||||
|
||||
return msgType, nil
|
||||
}
|
||||
|
||||
func (device *Device) getMsgType(packet *[]byte, junkSize int) (uint32, error) {
|
||||
msgTypeValue := binary.LittleEndian.Uint32((*packet)[junkSize : junkSize+4])
|
||||
msgType, err := device.awg.GetMagicHeaderMinFor(msgTypeValue)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get magic header min: %w", err)
|
||||
}
|
||||
|
||||
return msgType, nil
|
||||
}
|
||||
|
||||
func (device *Device) handleTransport(size int, packet *[]byte, buffer *[MaxMessageSize]byte) (uint32, error) {
|
||||
junkSize := device.awg.Cfg.TransportHeaderJunkSize
|
||||
|
||||
msgType, err := device.getMsgType(packet, junkSize)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get msg type: %w", err)
|
||||
}
|
||||
|
||||
if msgType != MessageTransportType {
|
||||
// probably a junk packet
|
||||
return 0, fmt.Errorf("Received message with unknown type: %d", msgType)
|
||||
}
|
||||
|
||||
if junkSize > 0 {
|
||||
// remove junk from buffer by shifting the packet
|
||||
// this buffer is also used for decryption, so it needs to be corrected
|
||||
copy((*buffer)[:size], (*packet)[junkSize:])
|
||||
size -= junkSize
|
||||
// need to reinitialize packet as well
|
||||
(*packet) = (*packet)[:size]
|
||||
}
|
||||
|
||||
return msgType, nil
|
||||
}
|
||||
|
||||
63
device/magic-header.go
Normal file
63
device/magic-header.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type magicHeader struct {
|
||||
start uint32
|
||||
end uint32
|
||||
}
|
||||
|
||||
func newMagicHeader(spec string) (*magicHeader, error) {
|
||||
parts := strings.Split(spec, "-")
|
||||
if len(parts) < 1 || len(parts) > 2 {
|
||||
return nil, errors.New("bad format")
|
||||
}
|
||||
|
||||
start, err := strconv.ParseUint(parts[0], 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %w", parts[0], err)
|
||||
}
|
||||
|
||||
var end uint64
|
||||
if len(parts) > 1 {
|
||||
end, err = strconv.ParseUint(parts[1], 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %w", parts[1], err)
|
||||
}
|
||||
} else {
|
||||
end = start
|
||||
}
|
||||
|
||||
if end < start {
|
||||
return nil, errors.New("wrong range specified")
|
||||
}
|
||||
|
||||
return &magicHeader{
|
||||
start: uint32(start),
|
||||
end: uint32(end),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *magicHeader) GenSpec() string {
|
||||
if h.start == h.end {
|
||||
return fmt.Sprintf("%d", h.start)
|
||||
}
|
||||
return fmt.Sprintf("%d-%d", h.start, h.end)
|
||||
}
|
||||
|
||||
func (h *magicHeader) Validate(val uint32) bool {
|
||||
return h.start <= val && val <= h.end
|
||||
}
|
||||
|
||||
func (h *magicHeader) Generate() uint32 {
|
||||
high := int64(h.end - h.start + 1)
|
||||
r, _ := rand.Int(rand.Reader, big.NewInt(high))
|
||||
return h.start + uint32(r.Int64())
|
||||
}
|
||||
@@ -53,17 +53,11 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultMessageInitiationType uint32 = 1
|
||||
DefaultMessageResponseType uint32 = 2
|
||||
DefaultMessageCookieReplyType uint32 = 3
|
||||
DefaultMessageTransportType uint32 = 4
|
||||
)
|
||||
|
||||
var (
|
||||
MessageInitiationType uint32 = DefaultMessageInitiationType
|
||||
MessageResponseType uint32 = DefaultMessageResponseType
|
||||
MessageCookieReplyType uint32 = DefaultMessageCookieReplyType
|
||||
MessageTransportType uint32 = DefaultMessageTransportType
|
||||
MessageUnknownType uint32 = 0
|
||||
MessageInitiationType uint32 = 1
|
||||
MessageResponseType uint32 = 2
|
||||
MessageCookieReplyType uint32 = 3
|
||||
MessageTransportType uint32 = 4
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -82,11 +76,6 @@ const (
|
||||
MessageTransportOffsetContent = 16
|
||||
)
|
||||
|
||||
var (
|
||||
packetSizeToMsgType map[int]uint32
|
||||
msgTypeToJunkSize map[uint32]int
|
||||
)
|
||||
|
||||
/* Type is an 8-bit field, followed by 3 nul bytes,
|
||||
* by marshalling the messages in little-endian byteorder
|
||||
* we can treat these as a 32-bit unsigned int (for now)
|
||||
@@ -205,17 +194,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
||||
|
||||
handshake.mixHash(handshake.remoteStatic[:])
|
||||
|
||||
msgType := DefaultMessageInitiationType
|
||||
if device.isAWG() {
|
||||
device.awg.Mux.RLock()
|
||||
msgType, err = device.awg.GetMsgType(DefaultMessageInitiationType)
|
||||
if err != nil {
|
||||
device.awg.Mux.RUnlock()
|
||||
return nil, fmt.Errorf("get message type: %w", err)
|
||||
}
|
||||
|
||||
device.awg.Mux.RUnlock()
|
||||
}
|
||||
msgType := device.headers.init.Generate()
|
||||
|
||||
msg := MessageInitiation{
|
||||
Type: msgType,
|
||||
@@ -274,13 +253,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||
chainKey [blake2s.Size]byte
|
||||
)
|
||||
|
||||
device.awg.Mux.RLock()
|
||||
|
||||
if msg.Type != MessageInitiationType {
|
||||
device.awg.Mux.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.awg.Mux.RUnlock()
|
||||
|
||||
device.staticIdentity.RLock()
|
||||
defer device.staticIdentity.RUnlock()
|
||||
@@ -395,19 +370,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||
}
|
||||
|
||||
var msg MessageResponse
|
||||
if device.isAWG() {
|
||||
device.awg.Mux.RLock()
|
||||
msg.Type, err = device.awg.GetMsgType(DefaultMessageResponseType)
|
||||
if err != nil {
|
||||
device.awg.Mux.RUnlock()
|
||||
return nil, fmt.Errorf("get message type: %w", err)
|
||||
}
|
||||
|
||||
device.awg.Mux.RUnlock()
|
||||
} else {
|
||||
msg.Type = DefaultMessageResponseType
|
||||
}
|
||||
|
||||
msg.Type = device.headers.response.Generate()
|
||||
msg.Sender = handshake.localIndex
|
||||
msg.Receiver = handshake.remoteIndex
|
||||
|
||||
@@ -457,13 +420,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||
}
|
||||
|
||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
device.awg.Mux.RLock()
|
||||
|
||||
if msg.Type != MessageResponseType {
|
||||
device.awg.Mux.RUnlock()
|
||||
return nil
|
||||
}
|
||||
device.awg.Mux.RUnlock()
|
||||
|
||||
// lookup handshake by receiver
|
||||
|
||||
|
||||
140
device/obf.go
Normal file
140
device/obf.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type obfBuilder func(val string) (obf, error)
|
||||
|
||||
var obfBuilders = map[string]obfBuilder{
|
||||
"b": newBytesObf,
|
||||
"t": newTimestampObf,
|
||||
"r": newRandObf,
|
||||
"rc": newRandCharObf,
|
||||
"rd": newRandDigitsObf,
|
||||
"d": newDataObf,
|
||||
"ds": newDataStringObf,
|
||||
"dz": newDataSizeObf,
|
||||
}
|
||||
|
||||
type obf interface {
|
||||
Obfuscate(dst, src []byte)
|
||||
Deobfuscate(dst, src []byte) bool
|
||||
ObfuscatedLen(srcLen int) int
|
||||
DeobfuscatedLen(srcLen int) int
|
||||
}
|
||||
|
||||
type obfChain struct {
|
||||
Spec string
|
||||
obfs []obf
|
||||
}
|
||||
|
||||
func newObfChain(spec string) (*obfChain, error) {
|
||||
var (
|
||||
obfs []obf
|
||||
errs []error
|
||||
)
|
||||
|
||||
remaining := spec[:]
|
||||
for {
|
||||
start := strings.IndexByte(remaining, '<')
|
||||
if start == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
end := strings.IndexByte(remaining[start:], '>')
|
||||
if end == -1 {
|
||||
return nil, errors.New("missing enclosing >")
|
||||
}
|
||||
end += start
|
||||
|
||||
tag := remaining[start+1 : end]
|
||||
parts := strings.Fields(tag)
|
||||
if len(parts) == 0 {
|
||||
errs = append(errs, errors.New("empty tag"))
|
||||
remaining = remaining[end+1:]
|
||||
continue
|
||||
}
|
||||
|
||||
key := parts[0]
|
||||
builder, ok := obfBuilders[key]
|
||||
if !ok {
|
||||
errs = append(errs, fmt.Errorf("unknown tag <%s>", key))
|
||||
remaining = remaining[end+1:]
|
||||
continue
|
||||
}
|
||||
|
||||
val := ""
|
||||
if len(parts) > 1 {
|
||||
val = parts[1]
|
||||
}
|
||||
|
||||
o, err := builder(val)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to build <%s>: %w", key, err))
|
||||
remaining = remaining[end+1:]
|
||||
continue
|
||||
}
|
||||
|
||||
obfs = append(obfs, o)
|
||||
remaining = remaining[end+1:]
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
return &obfChain{
|
||||
Spec: spec,
|
||||
obfs: obfs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *obfChain) Obfuscate(dst, src []byte) {
|
||||
written := 0
|
||||
for _, o := range c.obfs {
|
||||
obfLen := o.ObfuscatedLen(len(src))
|
||||
o.Obfuscate(dst[written:written+obfLen], src)
|
||||
written += obfLen
|
||||
}
|
||||
}
|
||||
|
||||
func (c *obfChain) Deobfuscate(dst, src []byte) bool {
|
||||
dynamicLen := len(src) - c.ObfuscatedLen(0)
|
||||
|
||||
written, read := 0, 0
|
||||
|
||||
for _, o := range c.obfs {
|
||||
deobfLen := o.DeobfuscatedLen(dynamicLen)
|
||||
obfLen := o.ObfuscatedLen(deobfLen)
|
||||
|
||||
if !o.Deobfuscate(dst[written:written+deobfLen], src[read:read+obfLen]) {
|
||||
return false
|
||||
}
|
||||
|
||||
written += deobfLen
|
||||
read += obfLen
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *obfChain) ObfuscatedLen(n int) int {
|
||||
total := 0
|
||||
for _, o := range c.obfs {
|
||||
total += o.ObfuscatedLen(n)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func (c *obfChain) DeobfuscatedLen(n int) int {
|
||||
dynamicLen := n - c.ObfuscatedLen(0)
|
||||
|
||||
total := 0
|
||||
for _, o := range c.obfs {
|
||||
total += o.DeobfuscatedLen(dynamicLen)
|
||||
}
|
||||
return total
|
||||
}
|
||||
47
device/obf_bytes.go
Normal file
47
device/obf_bytes.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func newBytesObf(val string) (obf, error) {
|
||||
val = strings.TrimPrefix(val, "0x")
|
||||
|
||||
if len(val) == 0 {
|
||||
return nil, errors.New("empty argument")
|
||||
}
|
||||
|
||||
if len(val)%2 != 0 {
|
||||
return nil, errors.New("odd amount of symbols")
|
||||
}
|
||||
|
||||
bytes, err := hex.DecodeString(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &bytesObf{data: bytes}, nil
|
||||
}
|
||||
|
||||
type bytesObf struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (o *bytesObf) Obfuscate(dst, src []byte) {
|
||||
copy(dst, o.data)
|
||||
}
|
||||
|
||||
func (o *bytesObf) Deobfuscate(dst, src []byte) bool {
|
||||
return bytes.Equal(o.data, src[:o.ObfuscatedLen(0)])
|
||||
}
|
||||
|
||||
func (o *bytesObf) ObfuscatedLen(srcLen int) int {
|
||||
return len(o.data)
|
||||
}
|
||||
|
||||
func (o *bytesObf) DeobfuscatedLen(srcLen int) int {
|
||||
return 0
|
||||
}
|
||||
25
device/obf_data.go
Normal file
25
device/obf_data.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package device
|
||||
|
||||
func newDataObf(val string) (obf, error) {
|
||||
return &dataObf{}, nil
|
||||
}
|
||||
|
||||
type dataObf struct {
|
||||
}
|
||||
|
||||
func (obf *dataObf) Obfuscate(dst, src []byte) {
|
||||
copy(dst, src)
|
||||
}
|
||||
|
||||
func (obf *dataObf) Deobfuscate(dst, src []byte) bool {
|
||||
copy(dst, src)
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *dataObf) ObfuscatedLen(n int) int {
|
||||
return n
|
||||
}
|
||||
|
||||
func (o *dataObf) DeobfuscatedLen(n int) int {
|
||||
return n
|
||||
}
|
||||
38
device/obf_datasize.go
Normal file
38
device/obf_datasize.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package device
|
||||
|
||||
import "strconv"
|
||||
|
||||
func newDataSizeObf(val string) (obf, error) {
|
||||
length, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dataSizeObf{
|
||||
length: length,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type dataSizeObf struct {
|
||||
length int
|
||||
}
|
||||
|
||||
func (o *dataSizeObf) Obfuscate(dst, src []byte) {
|
||||
srcLen := len(src)
|
||||
for i := o.length - 1; i >= 0; i-- {
|
||||
dst[i] = byte(srcLen & 0xFF)
|
||||
srcLen >>= 8
|
||||
}
|
||||
}
|
||||
|
||||
func (o *dataSizeObf) Deobfuscate(dst, src []byte) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *dataSizeObf) ObfuscatedLen(srcLen int) int {
|
||||
return o.length
|
||||
}
|
||||
|
||||
func (o *dataSizeObf) DeobfuscatedLen(srcLen int) int {
|
||||
return 0
|
||||
}
|
||||
29
device/obf_datastring.go
Normal file
29
device/obf_datastring.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
func newDataStringObf(val string) (obf, error) {
|
||||
return &dataStringObf{}, nil
|
||||
}
|
||||
|
||||
type dataStringObf struct {
|
||||
}
|
||||
|
||||
func (o *dataStringObf) Obfuscate(dst, src []byte) {
|
||||
base64.RawStdEncoding.Encode(dst, src)
|
||||
}
|
||||
|
||||
func (o *dataStringObf) Deobfuscate(dst, src []byte) bool {
|
||||
base64.RawStdEncoding.Decode(dst, src)
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *dataStringObf) ObfuscatedLen(n int) int {
|
||||
return base64.RawStdEncoding.EncodedLen(n)
|
||||
}
|
||||
|
||||
func (o *dataStringObf) DeobfuscatedLen(n int) int {
|
||||
return base64.RawStdEncoding.DecodedLen(n)
|
||||
}
|
||||
39
device/obf_rand.go
Normal file
39
device/obf_rand.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func newRandObf(val string) (obf, error) {
|
||||
length, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &randObf{
|
||||
length: length,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type randObf struct {
|
||||
length int
|
||||
}
|
||||
|
||||
func (o *randObf) Obfuscate(dst, src []byte) {
|
||||
rand.Read(dst[:o.length])
|
||||
}
|
||||
|
||||
func (o *randObf) Deobfuscate(dst, src []byte) bool {
|
||||
// there is no way to validate randomness :)
|
||||
// assume that it is always true
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *randObf) ObfuscatedLen(n int) int {
|
||||
return o.length
|
||||
}
|
||||
|
||||
func (o *randObf) DeobfuscatedLen(n int) int {
|
||||
return 0
|
||||
}
|
||||
48
device/obf_randchars.go
Normal file
48
device/obf_randchars.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"strconv"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
const chars52 = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
func newRandCharObf(val string) (obf, error) {
|
||||
length, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &randCharObf{
|
||||
length: length,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type randCharObf struct {
|
||||
length int
|
||||
}
|
||||
|
||||
func (o *randCharObf) Obfuscate(dst, src []byte) {
|
||||
rand.Read(dst[:o.length])
|
||||
for i := range dst[:o.length] {
|
||||
dst[i] = chars52[dst[i]%52]
|
||||
}
|
||||
}
|
||||
|
||||
func (o *randCharObf) Deobfuscate(dst, src []byte) bool {
|
||||
for _, b := range src[:o.length] {
|
||||
if !unicode.IsLetter(rune(b)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *randCharObf) ObfuscatedLen(n int) int {
|
||||
return o.length
|
||||
}
|
||||
|
||||
func (o *randCharObf) DeobfuscatedLen(n int) int {
|
||||
return 0
|
||||
}
|
||||
48
device/obf_randdigits.go
Normal file
48
device/obf_randdigits.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"strconv"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
const digits10 = "0123456789"
|
||||
|
||||
func newRandDigitsObf(val string) (obf, error) {
|
||||
length, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &randDigitObf{
|
||||
length: length,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type randDigitObf struct {
|
||||
length int
|
||||
}
|
||||
|
||||
func (o *randDigitObf) Obfuscate(dst, src []byte) {
|
||||
rand.Read(dst[:o.length])
|
||||
for i := range dst[:o.length] {
|
||||
dst[i] = digits10[dst[i]%10]
|
||||
}
|
||||
}
|
||||
|
||||
func (o *randDigitObf) Deobfuscate(dst, src []byte) bool {
|
||||
for _, b := range src[:o.length] {
|
||||
if !unicode.IsDigit(rune(b)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *randDigitObf) ObfuscatedLen(n int) int {
|
||||
return o.length
|
||||
}
|
||||
|
||||
func (o *randDigitObf) DeobfuscatedLen(n int) int {
|
||||
return 0
|
||||
}
|
||||
31
device/obf_timestamp.go
Normal file
31
device/obf_timestamp.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newTimestampObf(_ string) (obf, error) {
|
||||
return ×tampObf{}, nil
|
||||
}
|
||||
|
||||
type timestampObf struct{}
|
||||
|
||||
func (o *timestampObf) Obfuscate(dst, src []byte) {
|
||||
t := uint32(time.Now().Unix())
|
||||
binary.BigEndian.PutUint32(dst, t)
|
||||
}
|
||||
|
||||
func (o *timestampObf) Deobfuscate(dst, src []byte) bool {
|
||||
// replay attack check?
|
||||
// requires time to be always synchronized
|
||||
return true
|
||||
}
|
||||
|
||||
func (o *timestampObf) ObfuscatedLen(n int) int {
|
||||
return 4
|
||||
}
|
||||
|
||||
func (o *timestampObf) DeobfuscatedLen(n int) int {
|
||||
return 0
|
||||
}
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
)
|
||||
|
||||
type Peer struct {
|
||||
@@ -114,16 +113,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
func (peer *Peer) SendAndCountBuffers(buffers [][]byte) error {
|
||||
err := peer.SendBuffers(buffers)
|
||||
if err == nil {
|
||||
awg.PacketCounter.Add(uint64(len(buffers)))
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
||||
peer.device.net.RLock()
|
||||
defer peer.device.net.RUnlock()
|
||||
|
||||
@@ -97,13 +97,13 @@ func (device *Device) RoutineReceiveIncoming(
|
||||
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
|
||||
)
|
||||
|
||||
for i := range bufsArrs {
|
||||
for i := range maxBatchSize {
|
||||
bufsArrs[i] = device.GetMessageBuffer()
|
||||
bufs[i] = bufsArrs[i][:]
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for i := 0; i < maxBatchSize; i++ {
|
||||
for i := range maxBatchSize {
|
||||
if bufsArrs[i] != nil {
|
||||
device.PutMessageBuffer(bufsArrs[i])
|
||||
}
|
||||
@@ -129,7 +129,6 @@ func (device *Device) RoutineReceiveIncoming(
|
||||
}
|
||||
deathSpiral = 0
|
||||
|
||||
device.awg.Mux.RLock()
|
||||
// handle each packet in the batch
|
||||
for i, size := range sizes[:count] {
|
||||
if size < MinMessageSize {
|
||||
@@ -138,16 +137,12 @@ func (device *Device) RoutineReceiveIncoming(
|
||||
|
||||
// check size of packet
|
||||
packet := bufsArrs[i][:size]
|
||||
var msgType uint32
|
||||
if device.isAWG() {
|
||||
msgType, err = device.ProcessAWGPacket(size, &packet, bufsArrs[i])
|
||||
|
||||
if err != nil {
|
||||
device.log.Verbosef("awg: process packet: %v", err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
||||
// get message padding and type based on information from S1-S4 and H1-H4
|
||||
msgType, padding := device.DeterminePacketTypeAndPadding(packet, MessageUnknownType)
|
||||
if padding > 0 {
|
||||
copy(packet, packet[padding:])
|
||||
packet = packet[:len(packet)-padding]
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
@@ -233,7 +228,6 @@ func (device *Device) RoutineReceiveIncoming(
|
||||
default:
|
||||
}
|
||||
}
|
||||
device.awg.Mux.RUnlock()
|
||||
for peer, elemsContainer := range elemsByPeer {
|
||||
if peer.isRunning.Load() {
|
||||
peer.queue.inbound.c <- elemsContainer
|
||||
@@ -291,9 +285,6 @@ func (device *Device) RoutineHandshake(id int) {
|
||||
device.log.Verbosef("Routine: handshake worker %d - started", id)
|
||||
|
||||
for elem := range device.queue.handshake.c {
|
||||
|
||||
device.awg.Mux.RLock()
|
||||
|
||||
// handle cookie fields and ratelimiting
|
||||
|
||||
switch elem.msgType {
|
||||
@@ -450,7 +441,6 @@ func (device *Device) RoutineHandshake(id int) {
|
||||
peer.SendKeepalive()
|
||||
}
|
||||
skip:
|
||||
device.awg.Mux.RUnlock()
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
}
|
||||
}
|
||||
@@ -569,3 +559,57 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
||||
device.PutInboundElementsContainer(elemsContainer)
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) DeterminePacketTypeAndPadding(packet []byte, expectedType uint32) (uint32, int) {
|
||||
size := len(packet)
|
||||
|
||||
if expectedType == MessageUnknownType || expectedType == MessageInitiationType {
|
||||
padding := device.paddings.init
|
||||
header := device.headers.init
|
||||
|
||||
if size == padding+MessageInitiationSize {
|
||||
data := packet[padding:]
|
||||
if header.Validate(binary.LittleEndian.Uint32(data)) {
|
||||
return MessageInitiationType, padding
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if expectedType == MessageUnknownType || expectedType == MessageResponseType {
|
||||
padding := device.paddings.response
|
||||
header := device.headers.response
|
||||
|
||||
if size == padding+MessageResponseSize {
|
||||
data := packet[padding:]
|
||||
if header.Validate(binary.LittleEndian.Uint32(data)) {
|
||||
return MessageResponseType, padding
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if expectedType == MessageUnknownType || expectedType == MessageCookieReplyType {
|
||||
padding := device.paddings.cookie
|
||||
header := device.headers.cookie
|
||||
|
||||
if size == padding+MessageCookieReplySize {
|
||||
data := packet[padding:]
|
||||
if header.Validate(binary.LittleEndian.Uint32(data)) {
|
||||
return MessageCookieReplyType, padding
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if expectedType == MessageUnknownType || expectedType == MessageTransportType {
|
||||
padding := device.paddings.transport
|
||||
header := device.headers.transport
|
||||
|
||||
if size >= padding+MessageTransportHeaderSize {
|
||||
data := packet[padding:]
|
||||
if header.Validate(binary.LittleEndian.Uint32(data)) {
|
||||
return MessageTransportType, padding
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return MessageUnknownType, 0
|
||||
}
|
||||
|
||||
138
device/send.go
138
device/send.go
@@ -7,8 +7,10 @@ package device
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -123,41 +125,28 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
|
||||
return err
|
||||
}
|
||||
|
||||
var sendBuffer [][]byte
|
||||
|
||||
// so only packet processed for cookie generation
|
||||
var junkedHeader []byte
|
||||
if peer.device.version >= VersionAwg {
|
||||
var junks [][]byte
|
||||
if peer.device.version == VersionAwgSpecialHandshake {
|
||||
peer.device.awg.Mux.RLock()
|
||||
// set junks depending on packet type
|
||||
junks = peer.device.awg.HandshakeHandler.GenerateSpecialJunk()
|
||||
if junks != nil {
|
||||
peer.device.log.Verbosef("%v - Special junks sent", peer)
|
||||
}
|
||||
peer.device.awg.Mux.RUnlock()
|
||||
} else {
|
||||
junks = make([][]byte, 0, peer.device.awg.Cfg.JunkPacketCount)
|
||||
for _, ipacket := range peer.device.ipackets {
|
||||
if ipacket != nil {
|
||||
buf := make([]byte, ipacket.ObfuscatedLen(0))
|
||||
ipacket.Obfuscate(buf, nil)
|
||||
sendBuffer = append(sendBuffer, buf)
|
||||
}
|
||||
peer.device.awg.Mux.RLock()
|
||||
peer.device.awg.JunkCreator.CreateJunkPackets(&junks)
|
||||
peer.device.awg.Mux.RUnlock()
|
||||
}
|
||||
|
||||
if len(junks) > 0 {
|
||||
err = peer.SendBuffers(junks)
|
||||
jc := peer.device.junk.count
|
||||
jmin := peer.device.junk.min
|
||||
jmax := peer.device.junk.max
|
||||
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
for i := 0; i < jc; i++ {
|
||||
nBig, _ := rand.Int(rand.Reader, big.NewInt(int64(jmax-jmin+1)))
|
||||
n := int(nBig.Int64()) + jmin
|
||||
|
||||
junkedHeader, err = peer.device.awg.CreateInitHeaderJunk()
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
return err
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
rand.Read(buf)
|
||||
sendBuffer = append(sendBuffer, buf)
|
||||
}
|
||||
|
||||
var buf [MessageInitiationSize]byte
|
||||
@@ -165,14 +154,20 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
||||
binary.Write(writer, binary.LittleEndian, msg)
|
||||
packet := writer.Bytes()
|
||||
peer.cookieGenerator.AddMacs(packet)
|
||||
junkedHeader = append(junkedHeader, packet...)
|
||||
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
sendBuffer = append(sendBuffer, junkedHeader)
|
||||
if padding := peer.device.paddings.init; padding > 0 {
|
||||
buf := make([]byte, padding+len(packet))
|
||||
rand.Read(buf[:padding])
|
||||
copy(buf[padding:], packet)
|
||||
packet = buf
|
||||
}
|
||||
|
||||
err = peer.SendAndCountBuffers(sendBuffer)
|
||||
sendBuffer = append(sendBuffer, packet)
|
||||
|
||||
err = peer.SendBuffers(sendBuffer)
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
||||
}
|
||||
@@ -194,19 +189,12 @@ func (peer *Peer) SendHandshakeResponse() error {
|
||||
return err
|
||||
}
|
||||
|
||||
junkedHeader, err := peer.device.awg.CreateResponseHeaderJunk()
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - %v", peer, err)
|
||||
return err
|
||||
}
|
||||
|
||||
var buf [MessageResponseSize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
|
||||
binary.Write(writer, binary.LittleEndian, response)
|
||||
packet := writer.Bytes()
|
||||
peer.cookieGenerator.AddMacs(packet)
|
||||
junkedHeader = append(junkedHeader, packet...)
|
||||
|
||||
err = peer.BeginSymmetricSession()
|
||||
if err != nil {
|
||||
@@ -218,32 +206,26 @@ func (peer *Peer) SendHandshakeResponse() error {
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
if padding := peer.device.paddings.response; padding > 0 {
|
||||
buf := make([]byte, padding+len(packet))
|
||||
rand.Read(buf[:padding])
|
||||
copy(buf[padding:], packet)
|
||||
packet = buf
|
||||
}
|
||||
|
||||
// TODO: allocation could be avoided
|
||||
err = peer.SendAndCountBuffers([][]byte{junkedHeader})
|
||||
err = peer.SendBuffers([][]byte{packet})
|
||||
if err != nil {
|
||||
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (device *Device) SendHandshakeCookie(
|
||||
initiatingElem *QueueHandshakeElement,
|
||||
) error {
|
||||
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
|
||||
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
|
||||
|
||||
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
|
||||
msgType := DefaultMessageCookieReplyType
|
||||
if device.isAWG() {
|
||||
device.awg.Mux.RLock()
|
||||
|
||||
var err error
|
||||
msgType, err = device.awg.GetMsgType(DefaultMessageCookieReplyType)
|
||||
device.awg.Mux.RUnlock()
|
||||
if err != nil {
|
||||
device.log.Errorf("Get message type for cookie reply: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
msgType := device.headers.cookie.Generate()
|
||||
|
||||
reply, err := device.cookieChecker.CreateReply(
|
||||
initiatingElem.packet,
|
||||
@@ -256,19 +238,20 @@ func (device *Device) SendHandshakeCookie(
|
||||
return err
|
||||
}
|
||||
|
||||
junkedHeader, err := device.awg.CreateCookieReplyHeaderJunk()
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - %v", device, err)
|
||||
return err
|
||||
}
|
||||
|
||||
var buf [MessageCookieReplySize]byte
|
||||
writer := bytes.NewBuffer(buf[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
packet := writer.Bytes()
|
||||
|
||||
if padding := device.paddings.cookie; padding > 0 {
|
||||
buf := make([]byte, padding+len(packet))
|
||||
rand.Read(buf[:padding])
|
||||
copy(buf[padding:], packet)
|
||||
packet = buf
|
||||
}
|
||||
|
||||
junkedHeader = append(junkedHeader, writer.Bytes()...)
|
||||
// TODO: allocation could be avoided
|
||||
device.net.bind.Send([][]byte{junkedHeader}, initiatingElem.endpoint)
|
||||
device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -532,18 +515,7 @@ func (device *Device) RoutineEncryption(id int) {
|
||||
fieldReceiver := header[4:8]
|
||||
fieldNonce := header[8:16]
|
||||
|
||||
msgType := DefaultMessageTransportType
|
||||
if device.isAWG() {
|
||||
device.awg.Mux.RLock()
|
||||
|
||||
var err error
|
||||
msgType, err = device.awg.GetMsgType(DefaultMessageTransportType)
|
||||
device.awg.Mux.RUnlock()
|
||||
if err != nil {
|
||||
device.log.Errorf("get message type for transport: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
msgType := device.headers.transport.Generate()
|
||||
|
||||
binary.LittleEndian.PutUint32(fieldType, msgType)
|
||||
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
|
||||
@@ -603,13 +575,15 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
||||
if len(elem.packet) != MessageKeepaliveSize {
|
||||
dataSent = true
|
||||
|
||||
junkedHeader, err := device.awg.CreateTransportHeaderJunk(len(elem.packet))
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - %v", device, err)
|
||||
continue
|
||||
if padding := device.paddings.transport; padding > 0 {
|
||||
// elem.packet is stored at the start of elem.buffer
|
||||
// with zero padding
|
||||
for i := len(elem.packet) - 1; i >= 0; i-- {
|
||||
elem.buffer[i+padding] = elem.buffer[i]
|
||||
}
|
||||
rand.Read(elem.buffer[:padding])
|
||||
elem.packet = elem.buffer[:padding+len(elem.packet)]
|
||||
}
|
||||
|
||||
elem.packet = append(junkedHeader, elem.packet...)
|
||||
}
|
||||
bufs = append(bufs, elem.packet)
|
||||
}
|
||||
@@ -617,7 +591,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
||||
peer.timersAnyAuthenticatedPacketTraversal()
|
||||
peer.timersAnyAuthenticatedPacketSent()
|
||||
|
||||
err := peer.SendAndCountBuffers(bufs)
|
||||
err := peer.SendBuffers(bufs)
|
||||
if dataSent {
|
||||
peer.timersDataSent()
|
||||
}
|
||||
|
||||
299
device/uapi.go
299
device/uapi.go
@@ -18,7 +18,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/amnezia-vpn/amneziawg-go/device/awg"
|
||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
||||
)
|
||||
|
||||
@@ -98,42 +97,53 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||
sendf("fwmark=%d", device.net.fwmark)
|
||||
}
|
||||
|
||||
if device.isAWG() {
|
||||
if device.awg.Cfg.JunkPacketCount != 0 {
|
||||
sendf("jc=%d", device.awg.Cfg.JunkPacketCount)
|
||||
}
|
||||
if device.awg.Cfg.JunkPacketMinSize != 0 {
|
||||
sendf("jmin=%d", device.awg.Cfg.JunkPacketMinSize)
|
||||
}
|
||||
if device.awg.Cfg.JunkPacketMaxSize != 0 {
|
||||
sendf("jmax=%d", device.awg.Cfg.JunkPacketMaxSize)
|
||||
}
|
||||
if device.awg.Cfg.InitHeaderJunkSize != 0 {
|
||||
sendf("s1=%d", device.awg.Cfg.InitHeaderJunkSize)
|
||||
}
|
||||
if device.awg.Cfg.ResponseHeaderJunkSize != 0 {
|
||||
sendf("s2=%d", device.awg.Cfg.ResponseHeaderJunkSize)
|
||||
}
|
||||
if device.awg.Cfg.CookieReplyHeaderJunkSize != 0 {
|
||||
sendf("s3=%d", device.awg.Cfg.CookieReplyHeaderJunkSize)
|
||||
}
|
||||
if device.awg.Cfg.TransportHeaderJunkSize != 0 {
|
||||
sendf("s4=%d", device.awg.Cfg.TransportHeaderJunkSize)
|
||||
}
|
||||
for i, magicHeader := range device.awg.Cfg.MagicHeaders.Values {
|
||||
if magicHeader.Min > 4 {
|
||||
if magicHeader.Min == magicHeader.Max {
|
||||
sendf("h%d=%d", i+1, magicHeader.Min)
|
||||
continue
|
||||
}
|
||||
if device.junk.count != 0 {
|
||||
sendf("jc=%d", device.junk.count)
|
||||
}
|
||||
|
||||
sendf("h%d=%d-%d", i+1, magicHeader.Min, magicHeader.Max)
|
||||
}
|
||||
}
|
||||
if device.junk.min != 0 {
|
||||
sendf("jmin=%d", device.junk.min)
|
||||
}
|
||||
|
||||
specialJunkIpcFields := device.awg.HandshakeHandler.SpecialJunk.IpcGetFields()
|
||||
for _, field := range specialJunkIpcFields {
|
||||
sendf("%s=%s", field.Key, field.Value)
|
||||
if device.junk.max != 0 {
|
||||
sendf("jmax=%d", device.junk.max)
|
||||
}
|
||||
|
||||
if device.paddings.init != 0 {
|
||||
sendf("s1=%d", device.paddings.init)
|
||||
}
|
||||
|
||||
if device.paddings.response != 0 {
|
||||
sendf("s2=%d", device.paddings.response)
|
||||
}
|
||||
|
||||
if device.paddings.cookie != 0 {
|
||||
sendf("s3=%d", device.paddings.cookie)
|
||||
}
|
||||
|
||||
if device.paddings.transport != 0 {
|
||||
sendf("s4=%d", device.paddings.transport)
|
||||
}
|
||||
|
||||
if device.headers.init != nil {
|
||||
sendf("h1=%s", device.headers.init.GenSpec())
|
||||
}
|
||||
|
||||
if device.headers.response != nil {
|
||||
sendf("h2=%s", device.headers.response.GenSpec())
|
||||
}
|
||||
|
||||
if device.headers.cookie != nil {
|
||||
sendf("h3=%s", device.headers.cookie.GenSpec())
|
||||
}
|
||||
|
||||
if device.headers.transport != nil {
|
||||
sendf("h4=%s", device.headers.transport.GenSpec())
|
||||
}
|
||||
|
||||
for i, ipacket := range device.ipackets {
|
||||
if ipacket != nil {
|
||||
sendf("i%d=%s", i+1, ipacket.Spec)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,20 +197,18 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
ipcDev := new(ipcSetDevice)
|
||||
peer := new(ipcSetPeer)
|
||||
deviceConfig := true
|
||||
|
||||
tempAwg := awg.Protocol{}
|
||||
tempAwg.Cfg.MagicHeaders.Values = make([]awg.MagicHeader, 4)
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
// Blank line means terminate operation.
|
||||
err := device.handlePostConfig(&tempAwg)
|
||||
err := ipcDev.mergeWithDevice(device)
|
||||
if err != nil {
|
||||
return err
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to merge with device: %w", err)
|
||||
}
|
||||
peer.handlePostConfig()
|
||||
return nil
|
||||
@@ -229,7 +237,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||
|
||||
var err error
|
||||
if deviceConfig {
|
||||
err = device.handleDeviceLine(key, value, &tempAwg)
|
||||
err = device.handleDeviceLine(key, value)
|
||||
} else {
|
||||
err = device.handlePeerLine(peer, key, value)
|
||||
}
|
||||
@@ -237,9 +245,9 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = device.handlePostConfig(&tempAwg)
|
||||
err = ipcDev.mergeWithDevice(device)
|
||||
if err != nil {
|
||||
return err
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to merge with device: %w", err)
|
||||
}
|
||||
peer.handlePostConfig()
|
||||
|
||||
@@ -249,7 +257,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol) error {
|
||||
func (device *Device) handleDeviceLine(key, value string) error {
|
||||
switch key {
|
||||
case "private_key":
|
||||
var sk NoisePrivateKey
|
||||
@@ -300,112 +308,145 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
|
||||
device.RemoveAllPeers()
|
||||
|
||||
case "jc":
|
||||
junkPacketCount, err := strconv.Atoi(value)
|
||||
jc, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_count %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse jc: %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_count")
|
||||
tempAwg.Cfg.JunkPacketCount = junkPacketCount
|
||||
tempAwg.Cfg.IsSet = true
|
||||
if jc <= 0 {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "jc must be a positive value")
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk count")
|
||||
device.junk.count = jc
|
||||
|
||||
case "jmin":
|
||||
junkPacketMinSize, err := strconv.Atoi(value)
|
||||
jmin, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_min_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse jmin: %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
|
||||
tempAwg.Cfg.JunkPacketMinSize = junkPacketMinSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
if jmin <= 0 {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "jmin must be a positive value")
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk min")
|
||||
device.junk.min = jmin
|
||||
|
||||
case "jmax":
|
||||
junkPacketMaxSize, err := strconv.Atoi(value)
|
||||
jmax, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse junk_packet_max_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse jmax: %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
|
||||
tempAwg.Cfg.JunkPacketMaxSize = junkPacketMaxSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
if jmax <= 0 {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "jmax must be a positive value")
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating junk max")
|
||||
device.junk.max = jmax
|
||||
|
||||
case "s1":
|
||||
initPacketJunkSize, err := strconv.Atoi(value)
|
||||
padding, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse init_packet_junk_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse s1: %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
|
||||
tempAwg.Cfg.InitHeaderJunkSize = initPacketJunkSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
if padding < 0 {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "s1 must be non-negative")
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating s1 padding")
|
||||
device.paddings.init = padding
|
||||
|
||||
case "s2":
|
||||
responsePacketJunkSize, err := strconv.Atoi(value)
|
||||
padding, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse response_packet_junk_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse s2: %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
|
||||
tempAwg.Cfg.ResponseHeaderJunkSize = responsePacketJunkSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
if padding < 0 {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "s2 must be non-negative")
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating s2 padding")
|
||||
device.paddings.response = padding
|
||||
|
||||
case "s3":
|
||||
cookieReplyPacketJunkSize, err := strconv.Atoi(value)
|
||||
padding, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse cookie_reply_packet_junk_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse s3: %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating cookie_reply_packet_junk_size")
|
||||
tempAwg.Cfg.CookieReplyHeaderJunkSize = cookieReplyPacketJunkSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
if padding < 0 {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "s3 must be non-negative")
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating s3 padding")
|
||||
device.paddings.cookie = padding
|
||||
|
||||
case "s4":
|
||||
transportPacketJunkSize, err := strconv.Atoi(value)
|
||||
padding, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "parse transport_packet_junk_size %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse s4: %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating transport_packet_junk_size")
|
||||
tempAwg.Cfg.TransportHeaderJunkSize = transportPacketJunkSize
|
||||
tempAwg.Cfg.IsSet = true
|
||||
if padding < 0 {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "s4 must be non-negative")
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating s4 padding")
|
||||
device.paddings.transport = padding
|
||||
|
||||
case "h1":
|
||||
initMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
header, err := newMagicHeader(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse H1: %w", err)
|
||||
}
|
||||
device.headers.init = header
|
||||
|
||||
tempAwg.Cfg.MagicHeaders.Values[0] = initMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
case "h2":
|
||||
responseMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
header, err := newMagicHeader(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse H2: %w", err)
|
||||
}
|
||||
device.headers.response = header
|
||||
|
||||
tempAwg.Cfg.MagicHeaders.Values[1] = responseMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
case "h3":
|
||||
cookieReplyMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
header, err := newMagicHeader(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse H3: %w", err)
|
||||
}
|
||||
device.headers.cookie = header
|
||||
|
||||
tempAwg.Cfg.MagicHeaders.Values[2] = cookieReplyMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
case "h4":
|
||||
transportMagicHeader, err := awg.ParseMagicHeader(key, value)
|
||||
header, err := newMagicHeader(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "uapi: %w", err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse H4: %w", err)
|
||||
}
|
||||
device.headers.transport = header
|
||||
|
||||
tempAwg.Cfg.MagicHeaders.Values[3] = transportMagicHeader
|
||||
tempAwg.Cfg.IsSet = true
|
||||
case "i1", "i2", "i3", "i4", "i5":
|
||||
if len(value) == 0 {
|
||||
device.log.Verbosef("UAPI: received empty %s", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
generators, err := awg.ParseTagJunkGenerator(key, value)
|
||||
case "i1":
|
||||
chain, err := newObfChain(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse I1: %w", err)
|
||||
}
|
||||
device.log.Verbosef("UAPI: Updating %s", key)
|
||||
tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators)
|
||||
tempAwg.HandshakeHandler.IsSet = true
|
||||
device.ipackets[0] = chain
|
||||
|
||||
case "i2":
|
||||
chain, err := newObfChain(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse I2: %w", err)
|
||||
}
|
||||
device.ipackets[1] = chain
|
||||
|
||||
case "i3":
|
||||
chain, err := newObfChain(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse I3: %w", err)
|
||||
}
|
||||
device.ipackets[2] = chain
|
||||
|
||||
case "i4":
|
||||
chain, err := newObfChain(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse I4: %w", err)
|
||||
}
|
||||
device.ipackets[3] = chain
|
||||
|
||||
case "i5":
|
||||
chain, err := newObfChain(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse I5: %w", err)
|
||||
}
|
||||
device.ipackets[4] = chain
|
||||
|
||||
default:
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
||||
}
|
||||
@@ -654,3 +695,49 @@ func (device *Device) IpcHandle(socket net.Conn) {
|
||||
buffered.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
type ipcSetDevice struct {
|
||||
headers struct {
|
||||
init *magicHeader
|
||||
response *magicHeader
|
||||
cookie *magicHeader
|
||||
transport *magicHeader
|
||||
}
|
||||
}
|
||||
|
||||
func (d *ipcSetDevice) mergeWithDevice(device *Device) error {
|
||||
if d.headers.init == nil {
|
||||
d.headers.init = device.headers.init
|
||||
}
|
||||
|
||||
if d.headers.response == nil {
|
||||
d.headers.response = device.headers.response
|
||||
}
|
||||
|
||||
if d.headers.cookie == nil {
|
||||
d.headers.cookie = device.headers.cookie
|
||||
}
|
||||
|
||||
if d.headers.transport == nil {
|
||||
d.headers.transport = device.headers.transport
|
||||
}
|
||||
|
||||
headers := []*magicHeader{d.headers.init, d.headers.response, d.headers.cookie, d.headers.transport}
|
||||
for i := 0; i < len(headers); i++ {
|
||||
for j := i + 1; j < len(headers); j++ {
|
||||
left := headers[i]
|
||||
right := headers[j]
|
||||
|
||||
if left.start <= right.end && right.start <= left.end {
|
||||
return errors.New("headers must not overlap")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
device.headers.init = d.headers.init
|
||||
device.headers.response = d.headers.response
|
||||
device.headers.cookie = d.headers.cookie
|
||||
device.headers.transport = d.headers.transport
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user