chore: merge upstream

This commit is contained in:
Yaroslav Gurov
2026-02-19 15:03:12 +00:00
156 changed files with 8837 additions and 2351 deletions

View File

@@ -34,6 +34,22 @@ jobs:
if: steps.check-assets.outputs.missing == 'true'
run: sleep 90
check-proto:
runs-on: ubuntu-latest
steps:
- name: Checkout codebase
uses: actions/checkout@v6
- name: Check Proto Version Header
run: |
head -n 4 core/config.pb.go > ref.txt
find . -name "*.pb.go" ! -name "*_grpc.pb.go" -print0 | while IFS= read -r -d '' file; do
if ! cmp -s ref.txt <(head -n 4 "$file"); then
echo "Error: Header mismatch in $file"
head -n 4 "$file"
exit 1
fi
done
test:
needs: check-assets
permissions:

View File

@@ -12,9 +12,11 @@ import (
"sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/app/router"
"github.com/amnezia-vpn/amnezia-xray-core/common"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
"github.com/amnezia-vpn/amnezia-xray-core/common/net"
"github.com/amnezia-vpn/amnezia-xray-core/common/platform"
"github.com/amnezia-vpn/amnezia-xray-core/common/session"
"github.com/amnezia-vpn/amnezia-xray-core/common/strmatcher"
"github.com/amnezia-vpn/amnezia-xray-core/features/dns"
@@ -83,9 +85,31 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
return nil, errors.New("unexpected query strategy ", config.QueryStrategy)
}
hosts, err := NewStaticHosts(config.StaticHosts)
if err != nil {
return nil, errors.New("failed to create hosts").Base(err)
var hosts *StaticHosts
mphLoaded := false
domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" })
if domainMatcherPath != "" {
if f, err := os.Open(domainMatcherPath); err == nil {
defer f.Close()
if m, err := router.LoadGeoSiteMatcher(f, "HOSTS"); err == nil {
f.Seek(0, 0)
if hostIPs, err := router.LoadGeoSiteHosts(f); err == nil {
if sh, err := NewStaticHostsFromCache(m, hostIPs); err == nil {
hosts = sh
mphLoaded = true
errors.LogDebug(ctx, "MphDomainMatcher loaded from cache for DNS hosts, size: ", sh.matchers.Size())
}
}
}
}
}
if !mphLoaded {
sh, err := NewStaticHosts(config.StaticHosts)
if err != nil {
return nil, errors.New("failed to create hosts").Base(err)
}
hosts = sh
}
var clients []*Client

View File

@@ -14,7 +14,7 @@ import (
// StaticHosts represents static domain-ip mapping in DNS server.
type StaticHosts struct {
ips [][]net.Address
matchers *strmatcher.MatcherGroup
matchers strmatcher.IndexMatcher
}
// NewStaticHosts creates a new StaticHosts instance.
@@ -124,3 +124,50 @@ func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) (
func (h *StaticHosts) Lookup(domain string, option dns.IPOption) ([]net.Address, error) {
return h.lookup(domain, option, 5)
}
func NewStaticHostsFromCache(matcher strmatcher.IndexMatcher, hostIPs map[string][]string) (*StaticHosts, error) {
sh := &StaticHosts{
ips: make([][]net.Address, matcher.Size()+1),
matchers: matcher,
}
order := hostIPs["_ORDER"]
var offset uint32
img, ok := matcher.(*strmatcher.IndexMatcherGroup)
if !ok {
// Single matcher (e.g. only manual or only one geosite)
if len(order) > 0 {
pattern := order[0]
ips := parseIPs(hostIPs[pattern])
for i := uint32(1); i <= matcher.Size(); i++ {
sh.ips[i] = ips
}
}
return sh, nil
}
for i, m := range img.Matchers {
if i < len(order) {
pattern := order[i]
ips := parseIPs(hostIPs[pattern])
for j := uint32(1); j <= m.Size(); j++ {
sh.ips[offset+j] = ips
}
offset += m.Size()
}
}
return sh, nil
}
func parseIPs(raw []string) []net.Address {
addrs := make([]net.Address, 0, len(raw))
for _, s := range raw {
if len(s) > 1 && s[0] == '#' {
rcode, _ := strconv.Atoi(s[1:])
addrs = append(addrs, dns.RCodeError(rcode))
} else {
addrs = append(addrs, net.ParseAddress(s))
}
}
return addrs
}

View File

@@ -1,9 +1,11 @@
package dns_test
import (
"bytes"
"testing"
. "github.com/amnezia-vpn/amnezia-xray-core/app/dns"
"github.com/amnezia-vpn/amnezia-xray-core/app/router"
"github.com/amnezia-vpn/amnezia-xray-core/common"
"github.com/amnezia-vpn/amnezia-xray-core/common/net"
"github.com/amnezia-vpn/amnezia-xray-core/features/dns"
@@ -130,3 +132,57 @@ func TestStaticHosts(t *testing.T) {
}
}
}
func TestStaticHostsFromCache(t *testing.T) {
sites := []*router.GeoSite{
{
CountryCode: "cloudflare-dns.com",
Domain: []*router.Domain{
{Type: router.Domain_Full, Value: "example.com"},
},
},
{
CountryCode: "geosite:cn",
Domain: []*router.Domain{
{Type: router.Domain_Domain, Value: "baidu.cn"},
},
},
}
deps := map[string][]string{
"HOSTS": {"cloudflare-dns.com", "geosite:cn"},
}
hostIPs := map[string][]string{
"cloudflare-dns.com": {"1.1.1.1"},
"geosite:cn": {"2.2.2.2"},
"_ORDER": {"cloudflare-dns.com", "geosite:cn"},
}
var buf bytes.Buffer
err := router.SerializeGeoSiteList(sites, deps, hostIPs, &buf)
common.Must(err)
// Load matcher
m, err := router.LoadGeoSiteMatcher(bytes.NewReader(buf.Bytes()), "HOSTS")
common.Must(err)
// Load hostIPs
f := bytes.NewReader(buf.Bytes())
hips, err := router.LoadGeoSiteHosts(f)
common.Must(err)
hosts, err := NewStaticHostsFromCache(m, hips)
common.Must(err)
{
ips, _ := hosts.Lookup("example.com", dns.IPOption{IPv4Enable: true})
if len(ips) != 1 || ips[0].String() != "1.1.1.1" {
t.Error("failed to lookup example.com from cache")
}
}
{
ips, _ := hosts.Lookup("baidu.cn", dns.IPOption{IPv4Enable: true})
if len(ips) != 1 || ips[0].String() != "2.2.2.2" {
t.Error("failed to lookup baidu.cn from cache deps")
}
}
}

View File

@@ -10,6 +10,8 @@ import (
"github.com/amnezia-vpn/amnezia-xray-core/app/router"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
"github.com/amnezia-vpn/amnezia-xray-core/common/net"
"github.com/amnezia-vpn/amnezia-xray-core/common/platform"
"github.com/amnezia-vpn/amnezia-xray-core/common/platform/filesystem"
"github.com/amnezia-vpn/amnezia-xray-core/common/session"
"github.com/amnezia-vpn/amnezia-xray-core/common/strmatcher"
"github.com/amnezia-vpn/amnezia-xray-core/core"
@@ -17,6 +19,18 @@ import (
"github.com/amnezia-vpn/amnezia-xray-core/features/routing"
)
type mphMatcherWrapper struct {
m strmatcher.IndexMatcher
}
func (w *mphMatcherWrapper) Match(s string) bool {
return w.m.Match(s) != nil
}
func (w *mphMatcherWrapper) String() string {
return "mph-matcher"
}
// Server is the interface for Name Server.
type Server interface {
// Name of the Client.
@@ -132,29 +146,50 @@ func NewClient(
var rules []string
ruleCurr := 0
ruleIter := 0
for i, domain := range ns.PrioritizedDomain {
ns.PrioritizedDomain[i] = nil
domainRule, err := toStrMatcher(domain.Type, domain.Domain)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to create domain matcher, ignore domain rule [type: ", domain.Type, ", domain: ", domain.Domain, "]")
domainRule, _ = toStrMatcher(DomainMatchingType_Full, "hack.fix.index.for.illegal.domain.rule")
}
originalRuleIdx := ruleCurr
if ruleCurr < len(ns.OriginalRules) {
rule := ns.OriginalRules[ruleCurr]
if ruleCurr >= len(rules) {
rules = append(rules, rule.Rule)
// Check if domain matcher cache is provided via environment
domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" })
var mphLoaded bool
if domainMatcherPath != "" && ns.Tag != "" {
f, err := filesystem.NewFileReader(domainMatcherPath)
if err == nil {
defer f.Close()
g, err := router.LoadGeoSiteMatcher(f, ns.Tag)
if err == nil {
errors.LogDebug(ctx, "MphDomainMatcher loaded from cache for ", ns.Tag, " dns tag)")
updateDomainRule(&mphMatcherWrapper{m: g}, 0, *matcherInfos)
rules = append(rules, "[MPH Cache]")
mphLoaded = true
}
ruleIter++
if ruleIter >= int(rule.Size) {
ruleIter = 0
}
}
if !mphLoaded {
for i, domain := range ns.PrioritizedDomain {
ns.PrioritizedDomain[i] = nil
domainRule, err := toStrMatcher(domain.Type, domain.Domain)
if err != nil {
errors.LogErrorInner(ctx, err, "failed to create domain matcher, ignore domain rule [type: ", domain.Type, ", domain: ", domain.Domain, "]")
domainRule, _ = toStrMatcher(DomainMatchingType_Full, "hack.fix.index.for.illegal.domain.rule")
}
originalRuleIdx := ruleCurr
if ruleCurr < len(ns.OriginalRules) {
rule := ns.OriginalRules[ruleCurr]
if ruleCurr >= len(rules) {
rules = append(rules, rule.Rule)
}
ruleIter++
if ruleIter >= int(rule.Size) {
ruleIter = 0
ruleCurr++
}
} else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests)
rules = append(rules, domainRule.String())
ruleCurr++
}
} else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests)
rules = append(rules, domainRule.String())
ruleCurr++
updateDomainRule(domainRule, originalRuleIdx, *matcherInfos)
}
updateDomainRule(domainRule, originalRuleIdx, *matcherInfos)
}
ns.PrioritizedDomain = nil
runtime.GC()

View File

@@ -8,7 +8,6 @@ import (
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common"
@@ -19,6 +18,7 @@ import (
"github.com/amnezia-vpn/amnezia-xray-core/common/net/cnc"
"github.com/amnezia-vpn/amnezia-xray-core/common/protocol/dns"
"github.com/amnezia-vpn/amnezia-xray-core/common/session"
"github.com/amnezia-vpn/amnezia-xray-core/common/utils"
dns_feature "github.com/amnezia-vpn/amnezia-xray-core/features/dns"
"github.com/amnezia-vpn/amnezia-xray-core/features/routing"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet"
@@ -214,8 +214,8 @@ func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte,
req.Header.Add("Accept", "application/dns-message")
req.Header.Add("Content-Type", "application/dns-message")
req.Header.Set("X-Padding", strings.Repeat("X", int(crypto.RandBetween(100, 1000))))
req.Header.Set("User-Agent", utils.ChromeUA)
req.Header.Set("X-Padding", utils.H2Base62Pad(crypto.RandBetween(100, 1000)))
hc := s.httpClient

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.28.2
// - protoc-gen-go-grpc v1.6.1
// - protoc v3.21.12
// source: app/log/command/config.proto
package command
@@ -63,7 +63,7 @@ type LoggerServiceServer interface {
type UnimplementedLoggerServiceServer struct{}
func (UnimplementedLoggerServiceServer) RestartLogger(context.Context, *RestartLoggerRequest) (*RestartLoggerResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method RestartLogger not implemented")
return nil, status.Error(codes.Unimplemented, "method RestartLogger not implemented")
}
func (UnimplementedLoggerServiceServer) mustEmbedUnimplementedLoggerServiceServer() {}
func (UnimplementedLoggerServiceServer) testEmbeddedByValue() {}
@@ -76,7 +76,7 @@ type UnsafeLoggerServiceServer interface {
}
func RegisterLoggerServiceServer(s grpc.ServiceRegistrar, srv LoggerServiceServer) {
// If the following call pancis, it indicates UnimplementedLoggerServiceServer was
// If the following call panics, it indicates UnimplementedLoggerServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/net"
"github.com/amnezia-vpn/amnezia-xray-core/common/utils"
"github.com/amnezia-vpn/amnezia-xray-core/features/routing"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/tagged"
)
@@ -61,6 +62,7 @@ func (s *pingClient) MeasureDelay(httpMethod string) (time.Duration, error) {
if err != nil {
return rttFailed, err
}
req.Header.Set("User-Agent", utils.ChromeUA)
start := time.Now()
resp, err := s.httpClient.Do(req)

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.28.2
// - protoc-gen-go-grpc v1.6.1
// - protoc v3.21.12
// source: app/observatory/command/command.proto
package command
@@ -63,7 +63,7 @@ type ObservatoryServiceServer interface {
type UnimplementedObservatoryServiceServer struct{}
func (UnimplementedObservatoryServiceServer) GetOutboundStatus(context.Context, *GetOutboundStatusRequest) (*GetOutboundStatusResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetOutboundStatus not implemented")
return nil, status.Error(codes.Unimplemented, "method GetOutboundStatus not implemented")
}
func (UnimplementedObservatoryServiceServer) mustEmbedUnimplementedObservatoryServiceServer() {}
func (UnimplementedObservatoryServiceServer) testEmbeddedByValue() {}
@@ -76,7 +76,7 @@ type UnsafeObservatoryServiceServer interface {
}
func RegisterObservatoryServiceServer(s grpc.ServiceRegistrar, srv ObservatoryServiceServer) {
// If the following call pancis, it indicates UnimplementedObservatoryServiceServer was
// If the following call panics, it indicates UnimplementedObservatoryServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.

View File

@@ -15,6 +15,7 @@ import (
"github.com/amnezia-vpn/amnezia-xray-core/common/session"
"github.com/amnezia-vpn/amnezia-xray-core/common/signal/done"
"github.com/amnezia-vpn/amnezia-xray-core/common/task"
"github.com/amnezia-vpn/amnezia-xray-core/common/utils"
"github.com/amnezia-vpn/amnezia-xray-core/core"
"github.com/amnezia-vpn/amnezia-xray-core/features/extension"
"github.com/amnezia-vpn/amnezia-xray-core/features/outbound"
@@ -162,7 +163,9 @@ func (o *Observer) probe(outbound string) ProbeResult {
if o.config.ProbeUrl != "" {
probeURL = o.config.ProbeUrl
}
response, err := httpClient.Get(probeURL)
req, _ := http.NewRequest(http.MethodGet, probeURL, nil)
req.Header.Set("User-Agent", utils.ChromeUA)
response, err := httpClient.Do(req)
if err != nil {
return errors.New("outbound failed to relay connection").Base(err)
}

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.28.2
// - protoc-gen-go-grpc v1.6.1
// - protoc v3.21.12
// source: app/proxyman/command/command.proto
package command
@@ -180,34 +180,34 @@ type HandlerServiceServer interface {
type UnimplementedHandlerServiceServer struct{}
func (UnimplementedHandlerServiceServer) AddInbound(context.Context, *AddInboundRequest) (*AddInboundResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method AddInbound not implemented")
return nil, status.Error(codes.Unimplemented, "method AddInbound not implemented")
}
func (UnimplementedHandlerServiceServer) RemoveInbound(context.Context, *RemoveInboundRequest) (*RemoveInboundResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method RemoveInbound not implemented")
return nil, status.Error(codes.Unimplemented, "method RemoveInbound not implemented")
}
func (UnimplementedHandlerServiceServer) AlterInbound(context.Context, *AlterInboundRequest) (*AlterInboundResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method AlterInbound not implemented")
return nil, status.Error(codes.Unimplemented, "method AlterInbound not implemented")
}
func (UnimplementedHandlerServiceServer) ListInbounds(context.Context, *ListInboundsRequest) (*ListInboundsResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListInbounds not implemented")
return nil, status.Error(codes.Unimplemented, "method ListInbounds not implemented")
}
func (UnimplementedHandlerServiceServer) GetInboundUsers(context.Context, *GetInboundUserRequest) (*GetInboundUserResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetInboundUsers not implemented")
return nil, status.Error(codes.Unimplemented, "method GetInboundUsers not implemented")
}
func (UnimplementedHandlerServiceServer) GetInboundUsersCount(context.Context, *GetInboundUserRequest) (*GetInboundUsersCountResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetInboundUsersCount not implemented")
return nil, status.Error(codes.Unimplemented, "method GetInboundUsersCount not implemented")
}
func (UnimplementedHandlerServiceServer) AddOutbound(context.Context, *AddOutboundRequest) (*AddOutboundResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method AddOutbound not implemented")
return nil, status.Error(codes.Unimplemented, "method AddOutbound not implemented")
}
func (UnimplementedHandlerServiceServer) RemoveOutbound(context.Context, *RemoveOutboundRequest) (*RemoveOutboundResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method RemoveOutbound not implemented")
return nil, status.Error(codes.Unimplemented, "method RemoveOutbound not implemented")
}
func (UnimplementedHandlerServiceServer) AlterOutbound(context.Context, *AlterOutboundRequest) (*AlterOutboundResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method AlterOutbound not implemented")
return nil, status.Error(codes.Unimplemented, "method AlterOutbound not implemented")
}
func (UnimplementedHandlerServiceServer) ListOutbounds(context.Context, *ListOutboundsRequest) (*ListOutboundsResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ListOutbounds not implemented")
return nil, status.Error(codes.Unimplemented, "method ListOutbounds not implemented")
}
func (UnimplementedHandlerServiceServer) mustEmbedUnimplementedHandlerServiceServer() {}
func (UnimplementedHandlerServiceServer) testEmbeddedByValue() {}
@@ -220,7 +220,7 @@ type UnsafeHandlerServiceServer interface {
}
func RegisterHandlerServiceServer(s grpc.ServiceRegistrar, srv HandlerServiceServer) {
// If the following call pancis, it indicates UnimplementedHandlerServiceServer was
// If the following call panics, it indicates UnimplementedHandlerServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.0
// - protoc v6.33.2
// - protoc-gen-go-grpc v1.6.1
// - protoc v3.21.12
// source: app/router/command/command.proto
package command

View File

@@ -2,6 +2,7 @@ package router
import (
"context"
"io"
"os"
"path/filepath"
"regexp"
@@ -52,7 +53,34 @@ var matcherTypeMap = map[Domain_Type]strmatcher.Type{
}
type DomainMatcher struct {
matchers strmatcher.IndexMatcher
Matchers strmatcher.IndexMatcher
}
func SerializeDomainMatcher(domains []*Domain, w io.Writer) error {
g := strmatcher.NewMphMatcherGroup()
for _, d := range domains {
matcherType, f := matcherTypeMap[d.Type]
if !f {
continue
}
_, err := g.AddPattern(d.Value, matcherType)
if err != nil {
return err
}
}
g.Build()
// serialize
return g.Serialize(w)
}
func NewDomainMatcherFromBuffer(data []byte) (*strmatcher.MphMatcherGroup, error) {
matcher, err := strmatcher.NewMphMatcherGroupFromBuffer(data)
if err != nil {
return nil, err
}
return matcher, nil
}
func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) {
@@ -72,12 +100,12 @@ func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) {
}
g.Build()
return &DomainMatcher{
matchers: g,
Matchers: g,
}, nil
}
func (m *DomainMatcher) ApplyDomain(domain string) bool {
return len(m.matchers.Match(strings.ToLower(domain))) > 0
return len(m.Matchers.Match(strings.ToLower(domain))) > 0
}
// Apply implements Condition.

View File

@@ -0,0 +1,167 @@
package router_test
import (
"bytes"
"os"
"path/filepath"
"testing"
"github.com/amnezia-vpn/amnezia-xray-core/app/router"
"github.com/amnezia-vpn/amnezia-xray-core/common/platform/filesystem"
"github.com/stretchr/testify/require"
)
func TestDomainMatcherSerialization(t *testing.T) {
domains := []*router.Domain{
{Type: router.Domain_Domain, Value: "google.com"},
{Type: router.Domain_Domain, Value: "v2ray.com"},
{Type: router.Domain_Full, Value: "full.example.com"},
}
var buf bytes.Buffer
if err := router.SerializeDomainMatcher(domains, &buf); err != nil {
t.Fatalf("Serialize failed: %v", err)
}
matcher, err := router.NewDomainMatcherFromBuffer(buf.Bytes())
if err != nil {
t.Fatalf("Deserialize failed: %v", err)
}
dMatcher := &router.DomainMatcher{
Matchers: matcher,
}
testCases := []struct {
Input string
Match bool
}{
{"google.com", true},
{"maps.google.com", true},
{"v2ray.com", true},
{"full.example.com", true},
{"example.com", false},
}
for _, tc := range testCases {
if res := dMatcher.ApplyDomain(tc.Input); res != tc.Match {
t.Errorf("Match(%s) = %v, want %v", tc.Input, res, tc.Match)
}
}
}
func TestGeoSiteSerialization(t *testing.T) {
sites := []*router.GeoSite{
{
CountryCode: "CN",
Domain: []*router.Domain{
{Type: router.Domain_Domain, Value: "baidu.cn"},
{Type: router.Domain_Domain, Value: "qq.com"},
},
},
{
CountryCode: "US",
Domain: []*router.Domain{
{Type: router.Domain_Domain, Value: "google.com"},
{Type: router.Domain_Domain, Value: "facebook.com"},
},
},
}
var buf bytes.Buffer
if err := router.SerializeGeoSiteList(sites, nil, nil, &buf); err != nil {
t.Fatalf("SerializeGeoSiteList failed: %v", err)
}
tmp := t.TempDir()
path := filepath.Join(tmp, "matcher.cache")
f, err := os.Create(path)
require.NoError(t, err)
_, err = f.Write(buf.Bytes())
require.NoError(t, err)
f.Close()
f, err = os.Open(path)
require.NoError(t, err)
defer f.Close()
require.NoError(t, err)
data, _ := filesystem.ReadFile(path)
// cn
gp, err := router.LoadGeoSiteMatcher(bytes.NewReader(data), "CN")
if err != nil {
t.Fatalf("LoadGeoSiteMatcher(CN) failed: %v", err)
}
cnMatcher := &router.DomainMatcher{
Matchers: gp,
}
if !cnMatcher.ApplyDomain("baidu.cn") {
t.Error("CN matcher should match baidu.cn")
}
if cnMatcher.ApplyDomain("google.com") {
t.Error("CN matcher should NOT match google.com")
}
// us
gp, err = router.LoadGeoSiteMatcher(bytes.NewReader(data), "US")
if err != nil {
t.Fatalf("LoadGeoSiteMatcher(US) failed: %v", err)
}
usMatcher := &router.DomainMatcher{
Matchers: gp,
}
if !usMatcher.ApplyDomain("google.com") {
t.Error("US matcher should match google.com")
}
if usMatcher.ApplyDomain("baidu.cn") {
t.Error("US matcher should NOT match baidu.cn")
}
// unknown
_, err = router.LoadGeoSiteMatcher(bytes.NewReader(data), "unknown")
if err == nil {
t.Error("LoadGeoSiteMatcher(unknown) should fail")
}
}
func TestGeoSiteSerializationWithDeps(t *testing.T) {
sites := []*router.GeoSite{
{
CountryCode: "geosite:cn",
Domain: []*router.Domain{
{Type: router.Domain_Domain, Value: "baidu.cn"},
},
},
{
CountryCode: "geosite:google@cn",
Domain: []*router.Domain{
{Type: router.Domain_Domain, Value: "google.cn"},
},
},
{
CountryCode: "rule-1",
Domain: []*router.Domain{
{Type: router.Domain_Domain, Value: "google.com"},
},
},
}
deps := map[string][]string{
"rule-1": {"geosite:cn", "geosite:google@cn"},
}
var buf bytes.Buffer
err := router.SerializeGeoSiteList(sites, deps, nil, &buf)
require.NoError(t, err)
matcher, err := router.LoadGeoSiteMatcher(bytes.NewReader(buf.Bytes()), "rule-1")
require.NoError(t, err)
require.True(t, matcher.Match("google.com") != nil)
require.True(t, matcher.Match("baidu.cn") != nil)
require.True(t, matcher.Match("google.cn") != nil)
}

View File

@@ -7,6 +7,8 @@ import (
"strings"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
"github.com/amnezia-vpn/amnezia-xray-core/common/platform"
"github.com/amnezia-vpn/amnezia-xray-core/common/platform/filesystem"
"github.com/amnezia-vpn/amnezia-xray-core/features/outbound"
"github.com/amnezia-vpn/amnezia-xray-core/features/routing"
)
@@ -105,11 +107,25 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
}
if len(rr.Domain) > 0 {
matcher, err := NewMphMatcherGroup(rr.Domain)
if err != nil {
return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err)
var matcher *DomainMatcher
var err error
// Check if domain matcher cache is provided via environment
domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" })
if domainMatcherPath != "" {
matcher, err = GetDomainMatcherWithRuleTag(domainMatcherPath, rr.RuleTag)
if err != nil {
return nil, errors.New("failed to build domain condition from cached MphDomainMatcher").Base(err)
}
errors.LogDebug(context.Background(), "MphDomainMatcher loaded from cache for ", rr.RuleTag, " rule tag)")
} else {
matcher, err = NewMphMatcherGroup(rr.Domain)
if err != nil {
return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err)
}
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)")
}
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)")
conds.Add(matcher)
rr.Domain = nil
runtime.GC()
@@ -172,3 +188,20 @@ func (br *BalancingRule) Build(ohm outbound.Manager, dispatcher routing.Dispatch
return nil, errors.New("unrecognized balancer type")
}
}
func GetDomainMatcherWithRuleTag(domainMatcherPath string, ruleTag string) (*DomainMatcher, error) {
f, err := filesystem.NewFileReader(domainMatcherPath)
if err != nil {
return nil, errors.New("failed to load file: ", domainMatcherPath).Base(err)
}
defer f.Close()
g, err := LoadGeoSiteMatcher(f, ruleTag)
if err != nil {
return nil, errors.New("failed to load file:", domainMatcherPath).Base(err)
}
return &DomainMatcher{
Matchers: g,
}, nil
}

View File

@@ -0,0 +1,100 @@
package router
import (
"encoding/gob"
"errors"
"io"
"runtime"
"github.com/amnezia-vpn/amnezia-xray-core/common/strmatcher"
)
type geoSiteListGob struct {
Sites map[string][]byte
Deps map[string][]string
Hosts map[string][]string
}
func SerializeGeoSiteList(sites []*GeoSite, deps map[string][]string, hosts map[string][]string, w io.Writer) error {
data := geoSiteListGob{
Sites: make(map[string][]byte),
Deps: deps,
Hosts: hosts,
}
for _, site := range sites {
if site == nil {
continue
}
var buf bytesWriter
if err := SerializeDomainMatcher(site.Domain, &buf); err != nil {
return err
}
data.Sites[site.CountryCode] = buf.Bytes()
}
return gob.NewEncoder(w).Encode(data)
}
type bytesWriter struct {
data []byte
}
func (w *bytesWriter) Write(p []byte) (n int, err error) {
w.data = append(w.data, p...)
return len(p), nil
}
func (w *bytesWriter) Bytes() []byte {
return w.data
}
func LoadGeoSiteMatcher(r io.Reader, countryCode string) (strmatcher.IndexMatcher, error) {
var data geoSiteListGob
if err := gob.NewDecoder(r).Decode(&data); err != nil {
return nil, err
}
return loadWithDeps(&data, countryCode, make(map[string]bool))
}
func loadWithDeps(data *geoSiteListGob, code string, visited map[string]bool) (strmatcher.IndexMatcher, error) {
if visited[code] {
return nil, errors.New("cyclic dependency")
}
visited[code] = true
var matchers []strmatcher.IndexMatcher
if siteData, ok := data.Sites[code]; ok {
m, err := NewDomainMatcherFromBuffer(siteData)
if err == nil {
matchers = append(matchers, m)
}
}
if deps, ok := data.Deps[code]; ok {
for _, dep := range deps {
m, err := loadWithDeps(data, dep, visited)
if err == nil {
matchers = append(matchers, m)
}
}
}
if len(matchers) == 0 {
return nil, errors.New("matcher not found for: " + code)
}
if len(matchers) == 1 {
return matchers[0], nil
}
runtime.GC()
return &strmatcher.IndexMatcherGroup{Matchers: matchers}, nil
}
func LoadGeoSiteHosts(r io.Reader) (map[string][]string, error) {
var data geoSiteListGob
if err := gob.NewDecoder(r).Decode(&data); err != nil {
return nil, err
}
return data.Hosts, nil
}

View File

@@ -1,8 +1,8 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v6.32.0
// source: command.proto
// - protoc-gen-go-grpc v1.6.1
// - protoc v3.21.12
// source: app/stats/command/command.proto
package command
@@ -128,22 +128,22 @@ type StatsServiceServer interface {
type UnimplementedStatsServiceServer struct{}
func (UnimplementedStatsServiceServer) GetStats(context.Context, *GetStatsRequest) (*GetStatsResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetStats not implemented")
return nil, status.Error(codes.Unimplemented, "method GetStats not implemented")
}
func (UnimplementedStatsServiceServer) GetStatsOnline(context.Context, *GetStatsRequest) (*GetStatsResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetStatsOnline not implemented")
return nil, status.Error(codes.Unimplemented, "method GetStatsOnline not implemented")
}
func (UnimplementedStatsServiceServer) QueryStats(context.Context, *QueryStatsRequest) (*QueryStatsResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method QueryStats not implemented")
return nil, status.Error(codes.Unimplemented, "method QueryStats not implemented")
}
func (UnimplementedStatsServiceServer) GetSysStats(context.Context, *SysStatsRequest) (*SysStatsResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetSysStats not implemented")
return nil, status.Error(codes.Unimplemented, "method GetSysStats not implemented")
}
func (UnimplementedStatsServiceServer) GetStatsOnlineIpList(context.Context, *GetStatsRequest) (*GetStatsOnlineIpListResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetStatsOnlineIpList not implemented")
return nil, status.Error(codes.Unimplemented, "method GetStatsOnlineIpList not implemented")
}
func (UnimplementedStatsServiceServer) GetAllOnlineUsers(context.Context, *GetAllOnlineUsersRequest) (*GetAllOnlineUsersResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetAllOnlineUsers not implemented")
return nil, status.Error(codes.Unimplemented, "method GetAllOnlineUsers not implemented")
}
func (UnimplementedStatsServiceServer) mustEmbedUnimplementedStatsServiceServer() {}
func (UnimplementedStatsServiceServer) testEmbeddedByValue() {}
@@ -156,7 +156,7 @@ type UnsafeStatsServiceServer interface {
}
func RegisterStatsServiceServer(s grpc.ServiceRegistrar, srv StatsServiceServer) {
// If the following call pancis, it indicates UnimplementedStatsServiceServer was
// If the following call panics, it indicates UnimplementedStatsServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
@@ -307,5 +307,5 @@ var StatsService_ServiceDesc = grpc.ServiceDesc{
},
},
Streams: []grpc.StreamDesc{},
Metadata: "command.proto",
Metadata: "app/stats/command/command.proto",
}

View File

@@ -24,6 +24,8 @@ const (
XUDPBaseKey = "xray.xudp.basekey"
TunFdKey = "xray.tun.fd"
MphCachePath = "xray.mph.cache"
)
type EnvFlag struct {

View File

@@ -7,8 +7,8 @@ import (
const validCharCount = 53
type MatchType struct {
matchType Type
exist bool
Type Type
Exist bool
}
const (
@@ -17,23 +17,23 @@ const (
)
type Edge struct {
edgeType bool
nextNode int
Type bool
NextNode int
}
type ACAutomaton struct {
trie [][validCharCount]Edge
fail []int
exists []MatchType
count int
Trie [][validCharCount]Edge
Fail []int
Exists []MatchType
Count int
}
func newNode() [validCharCount]Edge {
var s [validCharCount]Edge
for i := range s {
s[i] = Edge{
edgeType: FailEdge,
nextNode: 0,
Type: FailEdge,
NextNode: 0,
}
}
return s
@@ -123,11 +123,11 @@ var char2Index = []int{
func NewACAutomaton() *ACAutomaton {
ac := new(ACAutomaton)
ac.trie = append(ac.trie, newNode())
ac.fail = append(ac.fail, 0)
ac.exists = append(ac.exists, MatchType{
matchType: Full,
exist: false,
ac.Trie = append(ac.Trie, newNode())
ac.Fail = append(ac.Fail, 0)
ac.Exists = append(ac.Exists, MatchType{
Type: Full,
Exist: false,
})
return ac
}
@@ -136,53 +136,53 @@ func (ac *ACAutomaton) Add(domain string, t Type) {
node := 0
for i := len(domain) - 1; i >= 0; i-- {
idx := char2Index[domain[i]]
if ac.trie[node][idx].nextNode == 0 {
ac.count++
if len(ac.trie) < ac.count+1 {
ac.trie = append(ac.trie, newNode())
ac.fail = append(ac.fail, 0)
ac.exists = append(ac.exists, MatchType{
matchType: Full,
exist: false,
if ac.Trie[node][idx].NextNode == 0 {
ac.Count++
if len(ac.Trie) < ac.Count+1 {
ac.Trie = append(ac.Trie, newNode())
ac.Fail = append(ac.Fail, 0)
ac.Exists = append(ac.Exists, MatchType{
Type: Full,
Exist: false,
})
}
ac.trie[node][idx] = Edge{
edgeType: TrieEdge,
nextNode: ac.count,
ac.Trie[node][idx] = Edge{
Type: TrieEdge,
NextNode: ac.Count,
}
}
node = ac.trie[node][idx].nextNode
node = ac.Trie[node][idx].NextNode
}
ac.exists[node] = MatchType{
matchType: t,
exist: true,
ac.Exists[node] = MatchType{
Type: t,
Exist: true,
}
switch t {
case Domain:
ac.exists[node] = MatchType{
matchType: Full,
exist: true,
ac.Exists[node] = MatchType{
Type: Full,
Exist: true,
}
idx := char2Index['.']
if ac.trie[node][idx].nextNode == 0 {
ac.count++
if len(ac.trie) < ac.count+1 {
ac.trie = append(ac.trie, newNode())
ac.fail = append(ac.fail, 0)
ac.exists = append(ac.exists, MatchType{
matchType: Full,
exist: false,
if ac.Trie[node][idx].NextNode == 0 {
ac.Count++
if len(ac.Trie) < ac.Count+1 {
ac.Trie = append(ac.Trie, newNode())
ac.Fail = append(ac.Fail, 0)
ac.Exists = append(ac.Exists, MatchType{
Type: Full,
Exist: false,
})
}
ac.trie[node][idx] = Edge{
edgeType: TrieEdge,
nextNode: ac.count,
ac.Trie[node][idx] = Edge{
Type: TrieEdge,
NextNode: ac.Count,
}
}
node = ac.trie[node][idx].nextNode
ac.exists[node] = MatchType{
matchType: t,
exist: true,
node = ac.Trie[node][idx].NextNode
ac.Exists[node] = MatchType{
Type: t,
Exist: true,
}
default:
break
@@ -192,8 +192,8 @@ func (ac *ACAutomaton) Add(domain string, t Type) {
func (ac *ACAutomaton) Build() {
queue := list.New()
for i := 0; i < validCharCount; i++ {
if ac.trie[0][i].nextNode != 0 {
queue.PushBack(ac.trie[0][i])
if ac.Trie[0][i].NextNode != 0 {
queue.PushBack(ac.Trie[0][i])
}
}
for {
@@ -201,16 +201,16 @@ func (ac *ACAutomaton) Build() {
if front == nil {
break
} else {
node := front.Value.(Edge).nextNode
node := front.Value.(Edge).NextNode
queue.Remove(front)
for i := 0; i < validCharCount; i++ {
if ac.trie[node][i].nextNode != 0 {
ac.fail[ac.trie[node][i].nextNode] = ac.trie[ac.fail[node]][i].nextNode
queue.PushBack(ac.trie[node][i])
if ac.Trie[node][i].NextNode != 0 {
ac.Fail[ac.Trie[node][i].NextNode] = ac.Trie[ac.Fail[node]][i].NextNode
queue.PushBack(ac.Trie[node][i])
} else {
ac.trie[node][i] = Edge{
edgeType: FailEdge,
nextNode: ac.trie[ac.fail[node]][i].nextNode,
ac.Trie[node][i] = Edge{
Type: FailEdge,
NextNode: ac.Trie[ac.Fail[node]][i].NextNode,
}
}
}
@@ -230,9 +230,9 @@ func (ac *ACAutomaton) Match(s string) bool {
return false
}
idx := char2Index[chr]
fullMatch = fullMatch && ac.trie[node][idx].edgeType
node = ac.trie[node][idx].nextNode
switch ac.exists[node].matchType {
fullMatch = fullMatch && ac.Trie[node][idx].Type
node = ac.Trie[node][idx].NextNode
switch ac.Exists[node].Type {
case Substr:
return true
case Domain:
@@ -243,5 +243,5 @@ func (ac *ACAutomaton) Match(s string) bool {
break
}
}
return fullMatch && ac.exists[node].exist
return fullMatch && ac.Exists[node].Exist
}

View File

@@ -39,14 +39,18 @@ func (m domainMatcher) String() string {
return "domain:" + string(m)
}
type regexMatcher struct {
pattern *regexp.Regexp
type RegexMatcher struct {
Pattern string
reg *regexp.Regexp
}
func (m *regexMatcher) Match(s string) bool {
return m.pattern.MatchString(s)
func (m *RegexMatcher) Match(s string) bool {
if m.reg == nil {
m.reg = regexp.MustCompile(m.Pattern)
}
return m.reg.MatchString(s)
}
func (m *regexMatcher) String() string {
return "regexp:" + m.pattern.String()
func (m *RegexMatcher) String() string {
return "regexp:" + m.Pattern
}

View File

@@ -25,40 +25,40 @@ func RollingHash(s string) uint32 {
// 2. `substr` patterns are matched by ac automaton;
// 3. `regex` patterns are matched with the regex library.
type MphMatcherGroup struct {
ac *ACAutomaton
otherMatchers []matcherEntry
rules []string
level0 []uint32
level0Mask int
level1 []uint32
level1Mask int
count uint32
ruleMap *map[string]uint32
Ac *ACAutomaton
OtherMatchers []MatcherEntry
Rules []string
Level0 []uint32
Level0Mask int
Level1 []uint32
Level1Mask int
Count uint32
RuleMap *map[string]uint32
}
func (g *MphMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) {
h := RollingHash(pattern)
switch t {
case Domain:
(*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.')
(*g.RuleMap)["."+pattern] = h*PrimeRK + uint32('.')
fallthrough
case Full:
(*g.ruleMap)[pattern] = h
(*g.RuleMap)[pattern] = h
default:
}
}
func NewMphMatcherGroup() *MphMatcherGroup {
return &MphMatcherGroup{
ac: nil,
otherMatchers: nil,
rules: nil,
level0: nil,
level0Mask: 0,
level1: nil,
level1Mask: 0,
count: 1,
ruleMap: &map[string]uint32{},
Ac: nil,
OtherMatchers: nil,
Rules: nil,
Level0: nil,
Level0Mask: 0,
Level1: nil,
Level1Mask: 0,
Count: 1,
RuleMap: &map[string]uint32{},
}
}
@@ -66,10 +66,10 @@ func NewMphMatcherGroup() *MphMatcherGroup {
func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) {
switch t {
case Substr:
if g.ac == nil {
g.ac = NewACAutomaton()
if g.Ac == nil {
g.Ac = NewACAutomaton()
}
g.ac.Add(pattern, t)
g.Ac.Add(pattern, t)
case Full, Domain:
pattern = strings.ToLower(pattern)
g.AddFullOrDomainPattern(pattern, t)
@@ -78,39 +78,39 @@ func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) {
if err != nil {
return 0, err
}
g.otherMatchers = append(g.otherMatchers, matcherEntry{
m: &regexMatcher{pattern: r},
id: g.count,
g.OtherMatchers = append(g.OtherMatchers, MatcherEntry{
M: &RegexMatcher{Pattern: pattern, reg: r},
Id: g.Count,
})
default:
panic("Unknown type")
}
return g.count, nil
return g.Count, nil
}
// Build builds a minimal perfect hash table and ac automaton from insert rules
func (g *MphMatcherGroup) Build() {
if g.ac != nil {
g.ac.Build()
if g.Ac != nil {
g.Ac.Build()
}
keyLen := len(*g.ruleMap)
keyLen := len(*g.RuleMap)
if keyLen == 0 {
keyLen = 1
(*g.ruleMap)["empty___"] = RollingHash("empty___")
(*g.RuleMap)["empty___"] = RollingHash("empty___")
}
g.level0 = make([]uint32, nextPow2(keyLen/4))
g.level0Mask = len(g.level0) - 1
g.level1 = make([]uint32, nextPow2(keyLen))
g.level1Mask = len(g.level1) - 1
sparseBuckets := make([][]int, len(g.level0))
g.Level0 = make([]uint32, nextPow2(keyLen/4))
g.Level0Mask = len(g.Level0) - 1
g.Level1 = make([]uint32, nextPow2(keyLen))
g.Level1Mask = len(g.Level1) - 1
sparseBuckets := make([][]int, len(g.Level0))
var ruleIdx int
for rule, hash := range *g.ruleMap {
n := int(hash) & g.level0Mask
g.rules = append(g.rules, rule)
for rule, hash := range *g.RuleMap {
n := int(hash) & g.Level0Mask
g.Rules = append(g.Rules, rule)
sparseBuckets[n] = append(sparseBuckets[n], ruleIdx)
ruleIdx++
}
g.ruleMap = nil
g.RuleMap = nil
var buckets []indexBucket
for n, vals := range sparseBuckets {
if len(vals) > 0 {
@@ -119,7 +119,7 @@ func (g *MphMatcherGroup) Build() {
}
sort.Sort(bySize(buckets))
occ := make([]bool, len(g.level1))
occ := make([]bool, len(g.Level1))
var tmpOcc []int
for _, bucket := range buckets {
seed := uint32(0)
@@ -127,7 +127,7 @@ func (g *MphMatcherGroup) Build() {
findSeed := true
tmpOcc = tmpOcc[:0]
for _, i := range bucket.vals {
n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask
n := int(strhashFallback(unsafe.Pointer(&g.Rules[i]), uintptr(seed))) & g.Level1Mask
if occ[n] {
for _, n := range tmpOcc {
occ[n] = false
@@ -138,10 +138,10 @@ func (g *MphMatcherGroup) Build() {
}
occ[n] = true
tmpOcc = append(tmpOcc, n)
g.level1[n] = uint32(i)
g.Level1[n] = uint32(i)
}
if findSeed {
g.level0[bucket.n] = seed
g.Level0[bucket.n] = seed
break
}
}
@@ -159,11 +159,11 @@ func nextPow2(v int) int {
// Lookup searches for s in t and returns its index and whether it was found.
func (g *MphMatcherGroup) Lookup(h uint32, s string) bool {
i0 := int(h) & g.level0Mask
seed := g.level0[i0]
i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask
n := g.level1[i1]
return s == g.rules[int(n)]
i0 := int(h) & g.Level0Mask
seed := g.Level0[i0]
i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.Level1Mask
n := g.Level1[i1]
return s == g.Rules[int(n)]
}
// Match implements IndexMatcher.Match.
@@ -183,13 +183,13 @@ func (g *MphMatcherGroup) Match(pattern string) []uint32 {
result = append(result, 1)
return result
}
if g.ac != nil && g.ac.Match(pattern) {
if g.Ac != nil && g.Ac.Match(pattern) {
result = append(result, 1)
return result
}
for _, e := range g.otherMatchers {
if e.m.Match(pattern) {
result = append(result, e.id)
for _, e := range g.OtherMatchers {
if e.M.Match(pattern) {
result = append(result, e.Id)
return result
}
}
@@ -302,3 +302,7 @@ func readUnaligned64(p unsafe.Pointer) uint64 {
q := (*[8]byte)(p)
return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56
}
func (g *MphMatcherGroup) Size() uint32 {
return g.Count
}

View File

@@ -0,0 +1,47 @@
package strmatcher
import (
"bytes"
"encoding/gob"
"io"
)
func init() {
gob.Register(&RegexMatcher{})
gob.Register(fullMatcher(""))
gob.Register(substrMatcher(""))
gob.Register(domainMatcher(""))
}
func (g *MphMatcherGroup) Serialize(w io.Writer) error {
data := MphMatcherGroup{
Ac: g.Ac,
OtherMatchers: g.OtherMatchers,
Rules: g.Rules,
Level0: g.Level0,
Level0Mask: g.Level0Mask,
Level1: g.Level1,
Level1Mask: g.Level1Mask,
Count: g.Count,
}
return gob.NewEncoder(w).Encode(data)
}
func NewMphMatcherGroupFromBuffer(data []byte) (*MphMatcherGroup, error) {
var gData MphMatcherGroup
if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&gData); err != nil {
return nil, err
}
g := NewMphMatcherGroup()
g.Ac = gData.Ac
g.OtherMatchers = gData.OtherMatchers
g.Rules = gData.Rules
g.Level0 = gData.Level0
g.Level0Mask = gData.Level0Mask
g.Level1 = gData.Level1
g.Level1Mask = gData.Level1Mask
g.Count = gData.Count
return g, nil
}

View File

@@ -41,8 +41,9 @@ func (t Type) New(pattern string) (Matcher, error) {
if err != nil {
return nil, err
}
return &regexMatcher{
pattern: r,
return &RegexMatcher{
Pattern: pattern,
reg: r,
}, nil
default:
return nil, errors.New("unk type")
@@ -53,11 +54,13 @@ func (t Type) New(pattern string) (Matcher, error) {
type IndexMatcher interface {
// Match returns the index of a matcher that matches the input. It returns empty array if no such matcher exists.
Match(input string) []uint32
// Size returns the number of matchers in the group.
Size() uint32
}
type matcherEntry struct {
m Matcher
id uint32
type MatcherEntry struct {
M Matcher
Id uint32
}
// MatcherGroup is an implementation of IndexMatcher.
@@ -66,7 +69,7 @@ type MatcherGroup struct {
count uint32
fullMatcher FullMatcherGroup
domainMatcher DomainMatcherGroup
otherMatchers []matcherEntry
otherMatchers []MatcherEntry
}
// Add adds a new Matcher into the MatcherGroup, and returns its index. The index will never be 0.
@@ -80,9 +83,9 @@ func (g *MatcherGroup) Add(m Matcher) uint32 {
case domainMatcher:
g.domainMatcher.addMatcher(tm, c)
default:
g.otherMatchers = append(g.otherMatchers, matcherEntry{
m: m,
id: c,
g.otherMatchers = append(g.otherMatchers, MatcherEntry{
M: m,
Id: c,
})
}
@@ -95,8 +98,8 @@ func (g *MatcherGroup) Match(pattern string) []uint32 {
result = append(result, g.fullMatcher.Match(pattern)...)
result = append(result, g.domainMatcher.Match(pattern)...)
for _, e := range g.otherMatchers {
if e.m.Match(pattern) {
result = append(result, e.id)
if e.M.Match(pattern) {
result = append(result, e.Id)
}
}
return result
@@ -106,3 +109,33 @@ func (g *MatcherGroup) Match(pattern string) []uint32 {
func (g *MatcherGroup) Size() uint32 {
return g.count
}
type IndexMatcherGroup struct {
Matchers []IndexMatcher
}
func (g *IndexMatcherGroup) Match(input string) []uint32 {
var offset uint32
for _, m := range g.Matchers {
if res := m.Match(input); len(res) > 0 {
if offset == 0 {
return res
}
shifted := make([]uint32, len(res))
for i, id := range res {
shifted[i] = id + offset
}
return shifted
}
offset += m.Size()
}
return nil
}
func (g *IndexMatcherGroup) Size() uint32 {
var count uint32
for _, m := range g.Matchers {
count += m.Size()
}
return count
}

View File

@@ -0,0 +1,17 @@
package utils
import (
"reflect"
"unsafe"
)
// AccessField can used to access unexported field of a struct
// valueType must be the exact type of the field or it will panic
func AccessField[valueType any](obj any, fieldName string) *valueType {
field := reflect.ValueOf(obj).Elem().FieldByName(fieldName)
if field.Type() != reflect.TypeOf(*new(valueType)) {
panic("field type: " + field.Type().String() + ", valueType: " + reflect.TypeOf(*new(valueType)).String())
}
v := (*valueType)(unsafe.Pointer(field.UnsafeAddr()))
return v
}

28
common/utils/browser.go Normal file
View File

@@ -0,0 +1,28 @@
package utils
import (
"math/rand"
"strconv"
"time"
"github.com/klauspost/cpuid/v2"
)
func ChromeVersion() int {
// Use only CPU info as seed for PRNG
seed := int64(cpuid.CPU.Family + cpuid.CPU.Model + cpuid.CPU.PhysicalCores + cpuid.CPU.LogicalCores + cpuid.CPU.CacheLine)
rng := rand.New(rand.NewSource(seed))
// Start from Chrome 144 released on 2026.1.13
releaseDate := time.Date(2026, 1, 13, 0, 0, 0, 0, time.UTC)
version := 144
now := time.Now()
// Each version has random 25-45 day interval
for releaseDate.Before(now) {
releaseDate = releaseDate.AddDate(0, 0, rng.Intn(21)+25)
version++
}
return version - 1
}
// ChromeUA provides default browser User-Agent based on CPU-seeded PRNG.
var ChromeUA = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/" + strconv.Itoa(ChromeVersion()) + ".0.0.0 Safari/537.36"

24
common/utils/padding.go Normal file
View File

@@ -0,0 +1,24 @@
package utils
import (
"math/rand/v2"
)
var (
// 8 ÷ (397/62)
h2packCorrectionFactor = 1.2493702770780857
base62TotalCharsNum = 62
base62Chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
)
// H2Base62Pad generates a base62 padding string for HTTP/2 header
// The total len will be slightly longer than the input to match the length after h2(h3 also) header huffman encoding
func H2Base62Pad[T int32 | int64 | int](expectedLen T) string {
actualLenFloat := float64(expectedLen) * h2packCorrectionFactor
actualLen := int(actualLenFloat)
result := make([]byte, actualLen)
for i := range actualLen {
result[i] = base62Chars[rand.N(base62TotalCharsNum)]
}
return string(result)
}

View File

@@ -18,8 +18,8 @@ import (
var (
Version_x byte = 26
Version_y byte = 1
Version_z byte = 23
Version_y byte = 2
Version_z byte = 6
)
var (

4
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/amnezia-vpn/amnezia-xray-core
go 1.25.6
go 1.25.7
require (
github.com/apernet/quic-go v0.57.2-0.20260111184307-eec823306178
@@ -9,6 +9,7 @@ require (
github.com/golang/mock v1.7.0-rc.1
github.com/google/go-cmp v0.7.0
github.com/gorilla/websocket v1.5.3
github.com/klauspost/cpuid/v2 v2.0.12
github.com/miekg/dns v1.1.72
github.com/pelletier/go-toml v1.9.5
github.com/pires/go-proxyproto v0.9.2
@@ -39,7 +40,6 @@ require (
github.com/google/btree v1.1.2 // indirect
github.com/juju/ratelimit v1.0.2 // indirect
github.com/klauspost/compress v1.17.4 // indirect
github.com/klauspost/cpuid/v2 v2.0.12 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect

View File

@@ -12,6 +12,7 @@ import (
"github.com/amnezia-vpn/amnezia-xray-core/app/router"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
"github.com/amnezia-vpn/amnezia-xray-core/common/net"
"github.com/amnezia-vpn/amnezia-xray-core/common/platform"
"github.com/amnezia-vpn/amnezia-xray-core/common/platform/filesystem"
"github.com/amnezia-vpn/amnezia-xray-core/common/serial"
"google.golang.org/protobuf/proto"
@@ -204,6 +205,13 @@ func loadIP(file, code string) ([]*router.CIDR, error) {
}
func loadSite(file, code string) ([]*router.Domain, error) {
// Check if domain matcher cache is provided via environment
domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" })
if domainMatcherPath != "" {
return []*router.Domain{{}}, nil
}
bs, err := loadFile(file, code)
if err != nil {
return nil, err

View File

@@ -4,72 +4,18 @@ import (
"sort"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/dns"
"github.com/amnezia-vpn/amnezia-xray-core/common/utils"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/http"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/noop"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/srtp"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/tls"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/utp"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wechat"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wireguard"
"google.golang.org/protobuf/proto"
)
type NoOpAuthenticator struct{}
func (NoOpAuthenticator) Build() (proto.Message, error) {
return new(noop.Config), nil
}
type NoOpConnectionAuthenticator struct{}
func (NoOpConnectionAuthenticator) Build() (proto.Message, error) {
return new(noop.ConnectionConfig), nil
}
type SRTPAuthenticator struct{}
func (SRTPAuthenticator) Build() (proto.Message, error) {
return new(srtp.Config), nil
}
type UTPAuthenticator struct{}
func (UTPAuthenticator) Build() (proto.Message, error) {
return new(utp.Config), nil
}
type WechatVideoAuthenticator struct{}
func (WechatVideoAuthenticator) Build() (proto.Message, error) {
return new(wechat.VideoConfig), nil
}
type WireguardAuthenticator struct{}
func (WireguardAuthenticator) Build() (proto.Message, error) {
return new(wireguard.WireguardConfig), nil
}
type DNSAuthenticator struct {
Domain string `json:"domain"`
}
func (v *DNSAuthenticator) Build() (proto.Message, error) {
config := new(dns.Config)
config.Domain = "www.baidu.com"
if len(v.Domain) > 0 {
config.Domain = v.Domain
}
return config, nil
}
type DTLSAuthenticator struct{}
func (DTLSAuthenticator) Build() (proto.Message, error) {
return new(tls.PacketConfig), nil
}
type AuthenticatorRequest struct {
Version string `json:"version"`
Method string `json:"method"`
@@ -95,11 +41,8 @@ func (v *AuthenticatorRequest) Build() (*http.RequestConfig, error) {
Value: []string{"www.baidu.com", "www.bing.com"},
},
{
Name: "User-Agent",
Value: []string{
"Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/53.0.2785.143 Safari/537.36",
"Mozilla/5.0 (iPhone; CPU iPhone OS 10_0_2 like Mac OS X) AppleWebKit/601.1 (KHTML, like Gecko) CriOS/53.0.2785.109 Mobile/14A456 Safari/601.1.46",
},
Name: "User-Agent",
Value: []string{utils.ChromeUA},
},
{
Name: "Accept-Encoding",

View File

@@ -1,6 +1,7 @@
package conf
import (
"context"
"encoding/base64"
"encoding/hex"
"encoding/json"
@@ -10,13 +11,24 @@ import (
"strconv"
"strings"
"syscall"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
"github.com/amnezia-vpn/amnezia-xray-core/common/net"
"github.com/amnezia-vpn/amnezia-xray-core/common/platform/filesystem"
"github.com/amnezia-vpn/amnezia-xray-core/common/serial"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dns"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dtls"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/srtp"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/utp"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wechat"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wireguard"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/aes128gcm"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/original"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/salamander"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xdns"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xicmp"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/httpupgrade"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/hysteria"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/kcp"
@@ -29,16 +41,6 @@ import (
)
var (
kcpHeaderLoader = NewJSONConfigLoader(ConfigCreatorCache{
"none": func() interface{} { return new(NoOpAuthenticator) },
"srtp": func() interface{} { return new(SRTPAuthenticator) },
"utp": func() interface{} { return new(UTPAuthenticator) },
"wechat-video": func() interface{} { return new(WechatVideoAuthenticator) },
"dtls": func() interface{} { return new(DTLSAuthenticator) },
"wireguard": func() interface{} { return new(WireguardAuthenticator) },
"dns": func() interface{} { return new(DNSAuthenticator) },
}, "type", "")
tcpHeaderLoader = NewJSONConfigLoader(ConfigCreatorCache{
"none": func() interface{} { return new(NoOpConnectionAuthenticator) },
"http": func() interface{} { return new(Authenticator) },
@@ -63,9 +65,9 @@ func (c *KCPConfig) Build() (proto.Message, error) {
if c.Mtu != nil {
mtu := *c.Mtu
if mtu < 576 || mtu > 1460 {
return nil, errors.New("invalid mKCP MTU size: ", mtu).AtError()
}
// if mtu < 576 || mtu > 1460 {
// return nil, errors.New("invalid mKCP MTU size: ", mtu).AtError()
// }
config.Mtu = &kcp.MTU{Value: mtu}
}
if c.Tti != nil {
@@ -100,20 +102,8 @@ func (c *KCPConfig) Build() (proto.Message, error) {
config.WriteBuffer = &kcp.WriteBuffer{Size: 512 * 1024}
}
}
if len(c.HeaderConfig) > 0 {
headerConfig, _, err := kcpHeaderLoader.Load(c.HeaderConfig)
if err != nil {
return nil, errors.New("invalid mKCP header config.").Base(err).AtError()
}
ts, err := headerConfig.(Buildable).Build()
if err != nil {
return nil, errors.New("invalid mKCP header config").Base(err).AtError()
}
config.HeaderConfig = serial.ToTypedMessage(ts)
}
if c.Seed != nil {
config.Seed = &kcp.EncryptionSeed{Seed: *c.Seed}
if c.HeaderConfig != nil || c.Seed != nil {
return nil, errors.PrintRemovedFeatureError("mkcp header & seed", "finalmask/udp header-* & mkcp-original & mkcp-aes128gcm")
}
return config, nil
@@ -228,6 +218,19 @@ type SplitHTTPConfig struct {
Mode string `json:"mode"`
Headers map[string]string `json:"headers"`
XPaddingBytes Int32Range `json:"xPaddingBytes"`
XPaddingObfsMode bool `json:"xPaddingObfsMode"`
XPaddingKey string `json:"xPaddingKey"`
XPaddingHeader string `json:"xPaddingHeader"`
XPaddingPlacement string `json:"xPaddingPlacement"`
XPaddingMethod string `json:"xPaddingMethod"`
UplinkHTTPMethod string `json:"uplinkHTTPMethod"`
SessionPlacement string `json:"sessionPlacement"`
SessionKey string `json:"sessionKey"`
SeqPlacement string `json:"seqPlacement"`
SeqKey string `json:"seqKey"`
UplinkDataPlacement string `json:"uplinkDataPlacement"`
UplinkDataKey string `json:"uplinkDataKey"`
UplinkChunkSize uint32 `json:"uplinkChunkSize"`
NoGRPCHeader bool `json:"noGRPCHeader"`
NoSSEHeader bool `json:"noSSEHeader"`
ScMaxEachPostBytes Int32Range `json:"scMaxEachPostBytes"`
@@ -287,6 +290,108 @@ func (c *SplitHTTPConfig) Build() (proto.Message, error) {
return nil, errors.New("xPaddingBytes cannot be disabled")
}
if c.XPaddingKey == "" {
c.XPaddingKey = "x_padding"
}
if c.XPaddingHeader == "" {
c.XPaddingHeader = "X-Padding"
}
switch c.XPaddingPlacement {
case "":
c.XPaddingPlacement = "queryInHeader"
case "cookie", "header", "query", "queryInHeader":
default:
return nil, errors.New("unsupported padding placement: " + c.XPaddingPlacement)
}
switch c.XPaddingMethod {
case "":
c.XPaddingMethod = "repeat-x"
case "repeat-x", "tokenish":
default:
return nil, errors.New("unsupported padding method: " + c.XPaddingMethod)
}
switch c.UplinkDataPlacement {
case "":
c.UplinkDataPlacement = "body"
case "body":
case "cookie", "header":
if c.Mode != "packet-up" {
return nil, errors.New("UplinkDataPlacement can be " + c.UplinkDataPlacement + " only in packet-up mode")
}
default:
return nil, errors.New("unsupported uplink data placement: " + c.UplinkDataPlacement)
}
if c.UplinkHTTPMethod == "" {
c.UplinkHTTPMethod = "POST"
}
c.UplinkHTTPMethod = strings.ToUpper(c.UplinkHTTPMethod)
if c.UplinkHTTPMethod == "GET" && c.Mode != "packet-up" {
return nil, errors.New("uplinkHTTPMethod can be GET only in packet-up mode")
}
switch c.SessionPlacement {
case "":
c.SessionPlacement = "path"
case "path", "cookie", "header", "query":
default:
return nil, errors.New("unsupported session placement: " + c.SessionPlacement)
}
switch c.SeqPlacement {
case "":
c.SeqPlacement = "path"
case "path", "cookie", "header", "query":
if c.SessionPlacement == "path" {
return nil, errors.New("SeqPlacement must be path when SessionPlacement is path")
}
default:
return nil, errors.New("unsupported seq placement: " + c.SeqPlacement)
}
if c.SessionPlacement != "path" && c.SessionKey == "" {
switch c.SessionPlacement {
case "cookie", "query":
c.SessionKey = "x_session"
case "header":
c.SessionKey = "X-Session"
}
}
if c.SeqPlacement != "path" && c.SeqKey == "" {
switch c.SeqPlacement {
case "cookie", "query":
c.SeqKey = "x_seq"
case "header":
c.SeqKey = "X-Seq"
}
}
if c.UplinkDataPlacement != "body" && c.UplinkDataKey == "" {
switch c.UplinkDataPlacement {
case "cookie":
c.UplinkDataKey = "x_data"
case "header":
c.UplinkDataKey = "X-Data"
}
}
if c.UplinkChunkSize == 0 {
switch c.UplinkDataPlacement {
case "cookie":
c.UplinkChunkSize = 3 * 1024 // 3KB
case "header":
c.UplinkChunkSize = 4 * 1024 // 4KB
}
} else if c.UplinkChunkSize < 64 {
c.UplinkChunkSize = 64
}
if c.Xmux.MaxConnections.To > 0 && c.Xmux.MaxConcurrency.To > 0 {
return nil, errors.New("maxConnections cannot be specified together with maxConcurrency")
}
@@ -305,6 +410,19 @@ func (c *SplitHTTPConfig) Build() (proto.Message, error) {
Mode: c.Mode,
Headers: c.Headers,
XPaddingBytes: newRangeConfig(c.XPaddingBytes),
XPaddingObfsMode: c.XPaddingObfsMode,
XPaddingKey: c.XPaddingKey,
XPaddingHeader: c.XPaddingHeader,
XPaddingPlacement: c.XPaddingPlacement,
XPaddingMethod: c.XPaddingMethod,
UplinkHTTPMethod: c.UplinkHTTPMethod,
SessionPlacement: c.SessionPlacement,
SeqPlacement: c.SeqPlacement,
SessionKey: c.SessionKey,
SeqKey: c.SeqKey,
UplinkDataPlacement: c.UplinkDataPlacement,
UplinkDataKey: c.UplinkDataKey,
UplinkChunkSize: c.UplinkChunkSize,
NoGRPCHeader: c.NoGRPCHeader,
NoSSEHeader: c.NoSSEHeader,
ScMaxEachPostBytes: newRangeConfig(c.ScMaxEachPostBytes),
@@ -631,7 +749,12 @@ func (c *TLSConfig) Build() (proto.Message, error) {
config.MasterKeyLog = c.MasterKeyLog
if c.AllowInsecure {
return nil, errors.PrintRemovedFeatureError(`"allowInsecure"`, `"pinnedPeerCertSha256"`)
if time.Now().After(time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC)) {
return nil, errors.PrintRemovedFeatureError(`"allowInsecure"`, `"pinnedPeerCertSha256"`)
} else {
errors.LogWarning(context.Background(), `"allowInsecure" will be removed automatically after 2026-06-01, please use "pinnedPeerCertSha256"(pcs) and "verifyPeerCertByName"(vcn) instead, PLEASE CONTACT YOUR SERVICE PROVIDER (AIRPORT)`)
config.AllowInsecure = true
}
}
if c.PinnedPeerCertSha256 != "" {
for v := range strings.SplitSeq(c.PinnedPeerCertSha256, ",") {
@@ -639,10 +762,14 @@ func (c *TLSConfig) Build() (proto.Message, error) {
if v == "" {
continue
}
hashValue, err := hex.DecodeString(v)
// remove colons for OpenSSL format
hashValue, err := hex.DecodeString(strings.ReplaceAll(v, ":", ""))
if err != nil {
return nil, err
}
if len(hashValue) != 32 {
return nil, errors.New("incorrect pinnedPeerCertSha256 length: ", v)
}
config.PinnedPeerCertSha256 = append(config.PinnedPeerCertSha256, hashValue)
}
}
@@ -1111,10 +1238,81 @@ func (c *SocketConfig) Build() (*internet.SocketConfig, error) {
var (
udpmaskLoader = NewJSONConfigLoader(ConfigCreatorCache{
"salamander": func() interface{} { return new(Salamander) },
"header-dns": func() interface{} { return new(Dns) },
"header-dtls": func() interface{} { return new(Dtls) },
"header-srtp": func() interface{} { return new(Srtp) },
"header-utp": func() interface{} { return new(Utp) },
"header-wechat": func() interface{} { return new(Wechat) },
"header-wireguard": func() interface{} { return new(Wireguard) },
"mkcp-original": func() interface{} { return new(Original) },
"mkcp-aes128gcm": func() interface{} { return new(Aes128Gcm) },
"salamander": func() interface{} { return new(Salamander) },
"xdns": func() interface{} { return new(Xdns) },
"xicmp": func() interface{} { return new(Xicmp) },
}, "type", "settings")
)
type Dns struct {
Domain string `json:"domain"`
}
func (c *Dns) Build() (proto.Message, error) {
config := &dns.Config{}
config.Domain = "www.baidu.com"
if len(c.Domain) > 0 {
config.Domain = c.Domain
}
return config, nil
}
type Dtls struct{}
func (c *Dtls) Build() (proto.Message, error) {
return &dtls.Config{}, nil
}
type Srtp struct{}
func (c *Srtp) Build() (proto.Message, error) {
return &srtp.Config{}, nil
}
type Utp struct{}
func (c *Utp) Build() (proto.Message, error) {
return &utp.Config{}, nil
}
type Wechat struct{}
func (c *Wechat) Build() (proto.Message, error) {
return &wechat.Config{}, nil
}
type Wireguard struct{}
func (c *Wireguard) Build() (proto.Message, error) {
return &wireguard.Config{}, nil
}
type Original struct{}
func (c *Original) Build() (proto.Message, error) {
return &original.Config{}, nil
}
type Aes128Gcm struct {
Password string `json:"password"`
}
func (c *Aes128Gcm) Build() (proto.Message, error) {
return &aes128gcm.Config{
Password: c.Password,
}, nil
}
type Salamander struct {
Password string `json:"password"`
}
@@ -1125,14 +1323,46 @@ func (c *Salamander) Build() (proto.Message, error) {
return config, nil
}
type FinalMask struct {
type Xdns struct {
Domain string `json:"domain"`
}
func (c *Xdns) Build() (proto.Message, error) {
if c.Domain == "" {
return nil, errors.New("empty domain")
}
return &xdns.Config{
Domain: c.Domain,
}, nil
}
type Xicmp struct {
ListenIp string `json:"listenIp"`
Id uint16 `json:"id"`
}
func (c *Xicmp) Build() (proto.Message, error) {
config := &xicmp.Config{
Ip: c.ListenIp,
Id: int32(c.Id),
}
if config.Ip == "" {
config.Ip = "0.0.0.0"
}
return config, nil
}
type Mask struct {
Type string `json:"type"`
Settings *json.RawMessage `json:"settings"`
}
func (c *FinalMask) Build(tcpmaskLoader bool) (proto.Message, error) {
func (c *Mask) Build(tcp bool) (proto.Message, error) {
loader := udpmaskLoader
if tcpmaskLoader {
if tcp {
return nil, errors.New("")
}
@@ -1151,12 +1381,17 @@ func (c *FinalMask) Build(tcpmaskLoader bool) (proto.Message, error) {
return ts, nil
}
type FinalMask struct {
Tcp []Mask `json:"tcp"`
Udp []Mask `json:"udp"`
}
type StreamConfig struct {
Address *Address `json:"address"`
Port uint16 `json:"port"`
Network *TransportProtocol `json:"network"`
Security string `json:"security"`
Udpmasks []*FinalMask `json:"udpmasks"`
FinalMask *FinalMask `json:"finalmask"`
TLSSettings *TLSConfig `json:"tlsSettings"`
REALITYSettings *REALITYConfig `json:"realitySettings"`
RAWSettings *TCPConfig `json:"rawSettings"`
@@ -1306,12 +1541,21 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) {
config.SocketSettings = ss
}
for _, mask := range c.Udpmasks {
u, err := mask.Build(false)
if err != nil {
return nil, errors.New("failed to build mask with type ", mask.Type).Base(err)
if c.FinalMask != nil {
for _, mask := range c.FinalMask.Tcp {
u, err := mask.Build(true)
if err != nil {
return nil, errors.New("failed to build mask with type ", mask.Type).Base(err)
}
config.Tcpmasks = append(config.Tcpmasks, serial.ToTypedMessage(u))
}
for _, mask := range c.FinalMask.Udp {
u, err := mask.Build(false)
if err != nil {
return nil, errors.New("failed to build mask with type ", mask.Type).Base(err)
}
config.Udpmasks = append(config.Udpmasks, serial.ToTypedMessage(u))
}
config.Udpmasks = append(config.Udpmasks, serial.ToTypedMessage(u))
}
return config, nil

View File

@@ -1,16 +1,21 @@
package conf
import (
"bytes"
"context"
"encoding/json"
"os"
"path/filepath"
"sort"
"strings"
"github.com/amnezia-vpn/amnezia-xray-core/app/dispatcher"
"github.com/amnezia-vpn/amnezia-xray-core/app/proxyman"
"github.com/amnezia-vpn/amnezia-xray-core/app/router"
"github.com/amnezia-vpn/amnezia-xray-core/app/stats"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
"github.com/amnezia-vpn/amnezia-xray-core/common/net"
"github.com/amnezia-vpn/amnezia-xray-core/common/platform"
"github.com/amnezia-vpn/amnezia-xray-core/common/serial"
core "github.com/amnezia-vpn/amnezia-xray-core/core"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet"
@@ -607,6 +612,187 @@ func (c *Config) Build() (*core.Config, error) {
return config, nil
}
func (c *Config) BuildMPHCache(customMatcherFilePath *string) error {
var geosite []*router.GeoSite
deps := make(map[string][]string)
uniqueGeosites := make(map[string]bool)
uniqueTags := make(map[string]bool)
matcherFilePath := platform.GetAssetLocation("matcher.cache")
if customMatcherFilePath != nil {
matcherFilePath = *customMatcherFilePath
}
processGeosite := func(dStr string) bool {
prefix := ""
if strings.HasPrefix(dStr, "geosite:") {
prefix = "geosite:"
} else if strings.HasPrefix(dStr, "ext-domain:") {
prefix = "ext-domain:"
}
if prefix == "" {
return false
}
key := strings.ToLower(dStr)
country := strings.ToUpper(dStr[len(prefix):])
if !uniqueGeosites[country] {
ds, err := loadGeositeWithAttr("geosite.dat", country)
if err == nil {
uniqueGeosites[country] = true
geosite = append(geosite, &router.GeoSite{CountryCode: key, Domain: ds})
}
}
return true
}
processDomains := func(tag string, rawDomains []string) {
var manualDomains []*router.Domain
var dDeps []string
for _, dStr := range rawDomains {
if processGeosite(dStr) {
dDeps = append(dDeps, strings.ToLower(dStr))
} else {
ds, err := parseDomainRule(dStr)
if err == nil {
manualDomains = append(manualDomains, ds...)
}
}
}
if len(manualDomains) > 0 {
if !uniqueTags[tag] {
uniqueTags[tag] = true
geosite = append(geosite, &router.GeoSite{CountryCode: tag, Domain: manualDomains})
}
}
if len(dDeps) > 0 {
deps[tag] = append(deps[tag], dDeps...)
}
}
// proccess rules
if c.RouterConfig != nil {
for _, rawRule := range c.RouterConfig.RuleList {
type SimpleRule struct {
RuleTag string `json:"ruleTag"`
Domain *StringList `json:"domain"`
Domains *StringList `json:"domains"`
}
var sr SimpleRule
json.Unmarshal(rawRule, &sr)
if sr.RuleTag == "" {
continue
}
var allDomains []string
if sr.Domain != nil {
allDomains = append(allDomains, *sr.Domain...)
}
if sr.Domains != nil {
allDomains = append(allDomains, *sr.Domains...)
}
processDomains(sr.RuleTag, allDomains)
}
}
// proccess dns servers
if c.DNSConfig != nil {
for _, ns := range c.DNSConfig.Servers {
if ns.Tag == "" {
continue
}
processDomains(ns.Tag, ns.Domains)
}
}
var hostIPs map[string][]string
if c.DNSConfig != nil && c.DNSConfig.Hosts != nil {
hostIPs = make(map[string][]string)
var hostDeps []string
var hostPatterns []string
// use raw map to avoid expanding geosites
var domains []string
for domain := range c.DNSConfig.Hosts.Hosts {
domains = append(domains, domain)
}
sort.Strings(domains)
manualHostGroups := make(map[string][]*router.Domain)
manualHostIPs := make(map[string][]string)
manualHostNames := make(map[string]string)
for _, domain := range domains {
ha := c.DNSConfig.Hosts.Hosts[domain]
m := getHostMapping(ha)
var ips []string
if m.ProxiedDomain != "" {
ips = append(ips, m.ProxiedDomain)
} else {
for _, ip := range m.Ip {
ips = append(ips, net.IPAddress(ip).String())
}
}
if processGeosite(domain) {
tag := strings.ToLower(domain)
hostDeps = append(hostDeps, tag)
hostIPs[tag] = ips
hostPatterns = append(hostPatterns, domain)
} else {
// build manual domains by their destination IPs
sort.Strings(ips)
ipKey := strings.Join(ips, ",")
ds, err := parseDomainRule(domain)
if err == nil {
manualHostGroups[ipKey] = append(manualHostGroups[ipKey], ds...)
manualHostIPs[ipKey] = ips
if _, ok := manualHostNames[ipKey]; !ok {
manualHostNames[ipKey] = domain
}
}
}
}
// create manual host groups
var ipKeys []string
for k := range manualHostGroups {
ipKeys = append(ipKeys, k)
}
sort.Strings(ipKeys)
for _, k := range ipKeys {
tag := manualHostNames[k]
geosite = append(geosite, &router.GeoSite{CountryCode: tag, Domain: manualHostGroups[k]})
hostDeps = append(hostDeps, tag)
hostIPs[tag] = manualHostIPs[k]
// record tag _ORDER links the matcher to IP addresses
hostPatterns = append(hostPatterns, tag)
}
deps["HOSTS"] = hostDeps
hostIPs["_ORDER"] = hostPatterns
}
f, err := os.Create(matcherFilePath)
if err != nil {
return err
}
defer f.Close()
var buf bytes.Buffer
if err := router.SerializeGeoSiteList(geosite, deps, hostIPs, &buf); err != nil {
return err
}
if _, err := f.Write(buf.Bytes()); err != nil {
return err
}
return nil
}
// Convert string to Address.
func ParseSendThough(Addr *string) *Address {
var addr Address

View File

@@ -0,0 +1,52 @@
package all
import (
"os"
"github.com/amnezia-vpn/amnezia-xray-core/common/platform"
"github.com/amnezia-vpn/amnezia-xray-core/infra/conf/serial"
"github.com/amnezia-vpn/amnezia-xray-core/main/commands/base"
)
var cmdBuildMphCache = &base.Command{
UsageLine: `{{.Exec}} buildMphCache [-c config.json] [-o domain.cache]`,
Short: `Build domain matcher cache`,
Long: `
Build domain matcher cache from a configuration file.
Example: {{.Exec}} buildMphCache -c config.json -o domain.cache
`,
}
func init() {
cmdBuildMphCache.Run = executeBuildMphCache
}
var (
configPath = cmdBuildMphCache.Flag.String("c", "config.json", "Config file path")
outputPath = cmdBuildMphCache.Flag.String("o", "domain.cache", "Output cache file path")
)
func executeBuildMphCache(cmd *base.Command, args []string) {
cf, err := os.Open(*configPath)
if err != nil {
base.Fatalf("failed to open config file: %v", err)
}
defer cf.Close()
// prevent using existing cache
domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" })
if domainMatcherPath != "" {
os.Setenv("XRAY_MPH_CACHE", "")
defer os.Setenv("XRAY_MPH_CACHE", domainMatcherPath)
}
config, err := serial.DecodeJSONConfig(cf)
if err != nil {
base.Fatalf("failed to decode config file: %v", err)
}
if err := config.BuildMPHCache(outputPath); err != nil {
base.Fatalf("failed to build MPH cache: %v", err)
}
}

View File

@@ -19,5 +19,6 @@ func init() {
cmdMLDSA65,
cmdMLKEM768,
cmdVLESSEnc,
cmdBuildMphCache,
)
}

View File

@@ -0,0 +1,78 @@
package tls
import (
"bytes"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"os"
"text/tabwriter"
"github.com/amnezia-vpn/amnezia-xray-core/main/commands/base"
. "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/tls"
)
var cmdHash = &base.Command{
UsageLine: "{{.Exec}} tls hash",
Short: "Calculate TLS certificate hash.",
Long: `
xray tls hash --cert <cert.pem>
Calculate TLS certificate hash.
`,
}
func init() {
cmdHash.Run = executeHash // break init loop
}
var input = cmdHash.Flag.String("cert", "fullchain.pem", "The file path of the certificate")
func executeHash(cmd *base.Command, args []string) {
fs := flag.NewFlagSet("hash", flag.ContinueOnError)
if err := fs.Parse(args); err != nil {
fmt.Println(err)
return
}
certContent, err := os.ReadFile(*input)
if err != nil {
fmt.Println(err)
return
}
var certs []*x509.Certificate
if bytes.Contains(certContent, []byte("BEGIN")) {
for {
block, remain := pem.Decode(certContent)
if block == nil {
break
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
fmt.Println("Unable to decode certificate:", err)
return
}
certs = append(certs, cert)
certContent = remain
}
} else {
certs, err = x509.ParseCertificates(certContent)
if err != nil {
fmt.Println("Unable to parse certificates:", err)
return
}
}
if len(certs) == 0 {
fmt.Println("No certificates found")
return
}
tabWriter := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
for i, cert := range certs {
hash := GenerateCertHashHex(cert)
if i == 0 {
fmt.Fprintf(tabWriter, "Leaf SHA256:\t%s\n", hash)
} else {
fmt.Fprintf(tabWriter, "CA <%s> SHA256:\t%s\n", cert.Subject.CommonName, hash)
}
}
tabWriter.Flush()
}

View File

@@ -1,44 +0,0 @@
package tls
import (
"flag"
"fmt"
"os"
"github.com/amnezia-vpn/amnezia-xray-core/main/commands/base"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/tls"
)
var cmdLeafCertHash = &base.Command{
UsageLine: "{{.Exec}} tls leafCertHash",
Short: "Calculate TLS leaf certificate hash.",
Long: `
xray tls leafCertHash --cert <cert.pem>
Calculate TLS leaf certificate hash.
`,
}
func init() {
cmdLeafCertHash.Run = executeLeafCertHash // break init loop
}
var input = cmdLeafCertHash.Flag.String("cert", "fullchain.pem", "The file path of the leaf certificate")
func executeLeafCertHash(cmd *base.Command, args []string) {
fs := flag.NewFlagSet("leafCertHash", flag.ContinueOnError)
if err := fs.Parse(args); err != nil {
fmt.Println(err)
return
}
certContent, err := os.ReadFile(*input)
if err != nil {
fmt.Println(err)
return
}
certChainHashB64, err := tls.CalculatePEMLeafCertSHA256Hash(certContent)
if err != nil {
fmt.Println("failed to decode cert", err)
return
}
fmt.Println(certChainHashB64)
}

View File

@@ -6,8 +6,13 @@ import (
"encoding/hex"
"fmt"
"net"
"os"
"strconv"
"text/tabwriter"
utls "github.com/refraction-networking/utls"
"github.com/amnezia-vpn/amnezia-xray-core/common/utils"
"github.com/amnezia-vpn/amnezia-xray-core/main/commands/base"
. "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/tls"
)
@@ -46,6 +51,7 @@ func executePing(cmd *base.Command, args []string) {
} else {
TargetPort, _ = strconv.Atoi(port)
}
tabWriter := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
var ip net.IP
if len(*pingIPStr) > 0 {
@@ -70,21 +76,20 @@ func executePing(cmd *base.Command, args []string) {
if err != nil {
base.Fatalf("Failed to dial tcp: %s", err)
}
tlsConn := gotls.Client(tcpConn, &gotls.Config{
tlsConn := GeneraticUClient(tcpConn, &gotls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"h2", "http/1.1"},
MaxVersion: gotls.VersionTLS13,
MinVersion: gotls.VersionTLS12,
// Do not release tool before v5's refactor
// VerifyPeerCertificate: showCert(),
})
err = tlsConn.Handshake()
if err != nil {
fmt.Println("Handshake failure: ", err)
} else {
fmt.Println("Handshake succeeded")
printTLSConnDetail(tlsConn)
printCertificates(tlsConn.ConnectionState().PeerCertificates)
printTLSConnDetail(tabWriter, tlsConn)
printCertificates(tabWriter, tlsConn.ConnectionState().PeerCertificates)
tabWriter.Flush()
}
tlsConn.Close()
}
@@ -96,21 +101,20 @@ func executePing(cmd *base.Command, args []string) {
if err != nil {
base.Fatalf("Failed to dial tcp: %s", err)
}
tlsConn := gotls.Client(tcpConn, &gotls.Config{
tlsConn := GeneraticUClient(tcpConn, &gotls.Config{
ServerName: domain,
NextProtos: []string{"h2", "http/1.1"},
MaxVersion: gotls.VersionTLS13,
MinVersion: gotls.VersionTLS12,
// Do not release tool before v5's refactor
// VerifyPeerCertificate: showCert(),
})
err = tlsConn.Handshake()
if err != nil {
fmt.Println("Handshake failure: ", err)
} else {
fmt.Println("Handshake succeeded")
printTLSConnDetail(tlsConn)
printCertificates(tlsConn.ConnectionState().PeerCertificates)
printTLSConnDetail(tabWriter, tlsConn)
printCertificates(tabWriter, tlsConn.ConnectionState().PeerCertificates)
tabWriter.Flush()
}
tlsConn.Close()
}
@@ -119,51 +123,45 @@ func executePing(cmd *base.Command, args []string) {
fmt.Println("TLS ping finished")
}
func printCertificates(certs []*x509.Certificate) {
func printCertificates(tabWriter *tabwriter.Writer, certs []*x509.Certificate) {
var leaf *x509.Certificate
var CAs []*x509.Certificate
var length int
for _, cert := range certs {
length += len(cert.Raw)
if len(cert.DNSNames) != 0 {
leaf = cert
} else {
CAs = append(CAs, cert)
}
}
fmt.Println("Certificate chain's total length: ", length, "(certs count: "+strconv.Itoa(len(certs))+")")
fmt.Fprintf(tabWriter, "Certificate chain's total length:\t%d (certs count: %s)\n", length, strconv.Itoa(len(certs)))
if leaf != nil {
fmt.Println("Cert's signature algorithm: ", leaf.SignatureAlgorithm.String())
fmt.Println("Cert's publicKey algorithm: ", leaf.PublicKeyAlgorithm.String())
fmt.Println("Cert's allowed domains: ", leaf.DNSNames)
fmt.Fprintf(tabWriter, "Cert's signature algorithm:\t%s\n", leaf.SignatureAlgorithm.String())
fmt.Fprintf(tabWriter, "Cert's publicKey algorithm:\t%s\n", leaf.PublicKeyAlgorithm.String())
fmt.Fprintf(tabWriter, "Cert's leaf SHA256:\t%s\n", hex.EncodeToString(GenerateCertHash(leaf)))
for _, ca := range CAs {
fmt.Fprintf(tabWriter, "Cert's CA <%s> SHA256:\t%s\n", ca.Subject.CommonName, hex.EncodeToString(GenerateCertHash(ca)))
}
fmt.Fprintf(tabWriter, "Cert's allowed domains:\t%v\n", leaf.DNSNames)
}
}
func printTLSConnDetail(tlsConn *gotls.Conn) {
func printTLSConnDetail(tabWriter *tabwriter.Writer, tlsConn *utls.UConn) {
connectionState := tlsConn.ConnectionState()
var tlsVersion string
if connectionState.Version == gotls.VersionTLS13 {
switch connectionState.Version {
case gotls.VersionTLS13:
tlsVersion = "TLS 1.3"
} else if connectionState.Version == gotls.VersionTLS12 {
case gotls.VersionTLS12:
tlsVersion = "TLS 1.2"
}
fmt.Println("TLS Version: ", tlsVersion)
curveID := connectionState.CurveID
if curveID != 0 {
PostQuantum := (curveID == gotls.X25519MLKEM768)
fmt.Println("TLS Post-Quantum key exchange: ", PostQuantum, "("+curveID.String()+")")
fmt.Fprintf(tabWriter, "TLS Version:\t%s\n", tlsVersion)
curveID := utils.AccessField[utls.CurveID](tlsConn.Conn, "curveID")
if curveID != nil {
PostQuantum := (*curveID == utls.X25519MLKEM768)
fmt.Fprintf(tabWriter, "TLS Post-Quantum key exchange:\t%t (%s)\n", PostQuantum, curveID.String())
} else {
fmt.Println("TLS Post-Quantum key exchange: false (RSA Exchange)")
}
}
func showCert() func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
var hash []byte
for _, asn1Data := range rawCerts {
cert, _ := x509.ParseCertificate(asn1Data)
if cert.IsCA {
hash = GenerateCertHash(cert)
}
}
fmt.Println("Certificate Leaf Hash: ", hex.EncodeToString(hash))
return nil
fmt.Fprintf(tabWriter, "TLS Post-Quantum key exchange: false (RSA Exchange)\n")
}
}

View File

@@ -13,7 +13,7 @@ var CmdTLS = &base.Command{
Commands: []*base.Command{
cmdCert,
cmdPing,
cmdLeafCertHash,
cmdHash,
cmdECH,
},
}

View File

@@ -63,11 +63,6 @@ import (
// Transport headers
_ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/http"
_ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/noop"
_ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/srtp"
_ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/tls"
_ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/utp"
_ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wechat"
_ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wireguard"
// JSON & TOML & YAML
_ "github.com/amnezia-vpn/amnezia-xray-core/main/json"

View File

@@ -21,6 +21,7 @@ import (
"github.com/amnezia-vpn/amnezia-xray-core/common/session"
"github.com/amnezia-vpn/amnezia-xray-core/common/signal"
"github.com/amnezia-vpn/amnezia-xray-core/common/task"
"github.com/amnezia-vpn/amnezia-xray-core/common/utils"
"github.com/amnezia-vpn/amnezia-xray-core/core"
"github.com/amnezia-vpn/amnezia-xray-core/features/policy"
"github.com/amnezia-vpn/amnezia-xray-core/transport"
@@ -219,6 +220,9 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u
for _, h := range header {
req.Header.Set(h.Key, h.Value)
}
if req.Header.Get("User-Agent") == "" {
req.Header.Set("User-Agent", utils.ChromeUA)
}
connectHTTP1 := func(rawConn net.Conn) (net.Conn, error) {
req.Header.Set("Proxy-Connection", "Keep-Alive")

View File

@@ -4,17 +4,15 @@ import (
"net"
)
type ConnSize interface {
Size() int32
}
type Udpmask interface {
UDP()
WrapConnClient(net.Conn) (net.Conn, error)
WrapConnServer(net.Conn) (net.Conn, error)
WrapPacketConnClient(net.PacketConn) (net.PacketConn, error)
WrapPacketConnServer(net.PacketConn) (net.PacketConn, error)
Size() int
Serialize([]byte)
WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error)
WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error)
}
type UdpmaskManager struct {
@@ -27,66 +25,32 @@ func NewUdpmaskManager(udpmasks []Udpmask) *UdpmaskManager {
}
}
func (m *UdpmaskManager) WrapConnClient(raw net.Conn) (net.Conn, error) {
var err error
for _, mask := range m.udpmasks {
raw, err = mask.WrapConnClient(raw)
if err != nil {
return nil, err
}
}
return raw, nil
}
func (m *UdpmaskManager) WrapConnServer(raw net.Conn) (net.Conn, error) {
var err error
for _, mask := range m.udpmasks {
raw, err = mask.WrapConnServer(raw)
if err != nil {
return nil, err
}
}
return raw, nil
}
func (m *UdpmaskManager) WrapPacketConnClient(raw net.PacketConn) (net.PacketConn, error) {
leaveSize := int32(0)
var err error
for _, mask := range m.udpmasks {
raw, err = mask.WrapPacketConnClient(raw)
for i, mask := range m.udpmasks {
raw, err = mask.WrapPacketConnClient(raw, i == len(m.udpmasks)-1, leaveSize, i == 0)
if err != nil {
return nil, err
}
leaveSize += raw.(ConnSize).Size()
}
return raw, nil
}
func (m *UdpmaskManager) WrapPacketConnServer(raw net.PacketConn) (net.PacketConn, error) {
leaveSize := int32(0)
var err error
for _, mask := range m.udpmasks {
raw, err = mask.WrapPacketConnServer(raw)
for i, mask := range m.udpmasks {
raw, err = mask.WrapPacketConnServer(raw, i == len(m.udpmasks)-1, leaveSize, i == 0)
if err != nil {
return nil, err
}
leaveSize += raw.(ConnSize).Size()
}
return raw, nil
}
func (m *UdpmaskManager) Size() int {
size := 0
for _, mask := range m.udpmasks {
size += mask.Size()
}
return size
}
func (m *UdpmaskManager) Serialize(b []byte) {
index := 0
for _, mask := range m.udpmasks {
mask.Serialize(b[index:])
index += mask.Size()
}
}
type Tcpmask interface {
TCP()

View File

@@ -0,0 +1,16 @@
package dns
import (
"net"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnServer(c, raw, first, leaveSize)
}

View File

@@ -0,0 +1,123 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v3.21.12
// source: transport/internet/finalmask/header/dns/config.proto
package dns
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Config) Reset() {
*x = Config{}
mi := &file_transport_internet_finalmask_header_dns_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Config) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Config) ProtoMessage() {}
func (x *Config) ProtoReflect() protoreflect.Message {
mi := &file_transport_internet_finalmask_header_dns_config_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_header_dns_config_proto_rawDescGZIP(), []int{0}
}
func (x *Config) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
var File_transport_internet_finalmask_header_dns_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_header_dns_config_proto_rawDesc = "" +
"\n" +
"4transport/internet/finalmask/header/dns/config.proto\x12,xray.transport.internet.finalmask.header.dns\" \n" +
"\x06Config\x12\x16\n" +
"\x06domain\x18\x01 \x01(\tR\x06domainB\xb5\x01\n" +
"0com.xray.transport.internet.finalmask.header.dnsP\x01ZPgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dns\xaa\x02,Xray.Transport.Internet.Finalmask.Header.Dnsb\x06proto3"
var (
file_transport_internet_finalmask_header_dns_config_proto_rawDescOnce sync.Once
file_transport_internet_finalmask_header_dns_config_proto_rawDescData []byte
)
func file_transport_internet_finalmask_header_dns_config_proto_rawDescGZIP() []byte {
file_transport_internet_finalmask_header_dns_config_proto_rawDescOnce.Do(func() {
file_transport_internet_finalmask_header_dns_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_dns_config_proto_rawDesc), len(file_transport_internet_finalmask_header_dns_config_proto_rawDesc)))
})
return file_transport_internet_finalmask_header_dns_config_proto_rawDescData
}
var file_transport_internet_finalmask_header_dns_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_header_dns_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.finalmask.header.dns.Config
}
var file_transport_internet_finalmask_header_dns_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_transport_internet_finalmask_header_dns_config_proto_init() }
func file_transport_internet_finalmask_header_dns_config_proto_init() {
if File_transport_internet_finalmask_header_dns_config_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_dns_config_proto_rawDesc), len(file_transport_internet_finalmask_header_dns_config_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_transport_internet_finalmask_header_dns_config_proto_goTypes,
DependencyIndexes: file_transport_internet_finalmask_header_dns_config_proto_depIdxs,
MessageInfos: file_transport_internet_finalmask_header_dns_config_proto_msgTypes,
}.Build()
File_transport_internet_finalmask_header_dns_config_proto = out.File
file_transport_internet_finalmask_header_dns_config_proto_goTypes = nil
file_transport_internet_finalmask_header_dns_config_proto_depIdxs = nil
}

View File

@@ -0,0 +1,11 @@
syntax = "proto3";
package xray.transport.internet.finalmask.header.dns;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Dns";
option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dns";
option java_package = "com.xray.transport.internet.finalmask.header.dns";
option java_multiple_files = true;
message Config {
string domain = 1;
}

View File

@@ -0,0 +1,241 @@
package dns
import (
"encoding/binary"
"io"
"net"
sync "sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/dice"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
)
func packDomainName(s string, msg []byte) (off1 int, err error) {
off := 0
ls := len(s)
// Each dot ends a segment of the name.
// We trade each dot byte for a length byte.
// Except for escaped dots (\.), which are normal dots.
// There is also a trailing zero.
// Emit sequence of counted strings, chopping at dots.
var (
begin int
bs []byte
)
for i := 0; i < ls; i++ {
var c byte
if bs == nil {
c = s[i]
} else {
c = bs[i]
}
switch c {
case '\\':
if off+1 > len(msg) {
return len(msg), errors.New("buffer size too small")
}
if bs == nil {
bs = []byte(s)
}
copy(bs[i:ls-1], bs[i+1:])
ls--
case '.':
labelLen := i - begin
if labelLen >= 1<<6 { // top two bits of length must be clear
return len(msg), errors.New("bad rdata")
}
// off can already (we're in a loop) be bigger than len(msg)
// this happens when a name isn't fully qualified
if off+1+labelLen > len(msg) {
return len(msg), errors.New("buffer size too small")
}
// The following is covered by the length check above.
msg[off] = byte(labelLen)
if bs == nil {
copy(msg[off+1:], s[begin:i])
} else {
copy(msg[off+1:], bs[begin:i])
}
off += 1 + labelLen
begin = i + 1
default:
}
}
if off < len(msg) {
msg[off] = 0
}
return off + 1, nil
}
type dns struct {
header []byte
}
func (h *dns) Size() int32 {
return int32(len(h.header))
}
func (h *dns) Serialize(b []byte) {
copy(b, h.header)
binary.BigEndian.PutUint16(b[0:], dice.RollUint16())
}
type dnsConn struct {
first bool
leaveSize int32
conn net.PacketConn
header *dns
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
var header []byte
header = binary.BigEndian.AppendUint16(header, 0x0000) // Transaction ID
header = binary.BigEndian.AppendUint16(header, 0x0100) // Flags: Standard query
header = binary.BigEndian.AppendUint16(header, 0x0001) // Questions
header = binary.BigEndian.AppendUint16(header, 0x0000) // Answer RRs
header = binary.BigEndian.AppendUint16(header, 0x0000) // Authority RRs
header = binary.BigEndian.AppendUint16(header, 0x0000) // Additional RRs
buf := make([]byte, 0x100)
off1, err := packDomainName(c.Domain+".", buf)
if err != nil {
return nil, err
}
header = append(header, buf[:off1]...)
header = binary.BigEndian.AppendUint16(header, 0x0001) // Type: A
header = binary.BigEndian.AppendUint16(header, 0x0001) // Class: IN
conn := &dnsConn{
first: first,
leaveSize: leaveSize,
conn: raw,
header: &dns{
header: header,
},
}
if first {
conn.readBuf = make([]byte, 8192)
conn.writeBuf = make([]byte, 8192)
}
return conn, nil
}
func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *dnsConn) Size() int32 {
return c.header.Size()
}
func (c *dnsConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if c.first {
c.readMutex.Lock()
n, addr, err = c.conn.ReadFrom(c.readBuf)
if err != nil {
c.readMutex.Unlock()
return n, addr, err
}
if n < int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
if len(p) < n-int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, c.readBuf[c.Size():n])
c.readMutex.Unlock()
return n - int(c.Size()), addr, err
}
n, addr, err = c.conn.ReadFrom(p)
if err != nil {
return n, addr, err
}
if n < int(c.Size()) {
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, p[c.Size():n])
return n - int(c.Size()), addr, err
}
func (c *dnsConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if c.first {
if c.leaveSize+c.Size()+int32(len(p)) > 8192 {
return 0, errors.New("too many masks")
}
c.writeMutex.Lock()
n = copy(c.writeBuf[c.leaveSize+c.Size():], p)
n += int(c.leaveSize) + int(c.Size())
c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()])
nn, err := c.conn.WriteTo(c.writeBuf[:n], addr)
if err != nil {
c.writeMutex.Unlock()
return 0, err
}
if nn != n {
c.writeMutex.Unlock()
return 0, errors.New("nn != n")
}
c.writeMutex.Unlock()
return len(p), nil
}
c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()])
return c.conn.WriteTo(p, addr)
}
func (c *dnsConn) Close() error {
return c.conn.Close()
}
func (c *dnsConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *dnsConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *dnsConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *dnsConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

View File

@@ -0,0 +1,16 @@
package dtls
import (
"net"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnServer(c, raw, first, leaveSize)
}

View File

@@ -0,0 +1,114 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v3.21.12
// source: transport/internet/finalmask/header/dtls/config.proto
package dtls
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Config) Reset() {
*x = Config{}
mi := &file_transport_internet_finalmask_header_dtls_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Config) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Config) ProtoMessage() {}
func (x *Config) ProtoReflect() protoreflect.Message {
mi := &file_transport_internet_finalmask_header_dtls_config_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_header_dtls_config_proto_rawDescGZIP(), []int{0}
}
var File_transport_internet_finalmask_header_dtls_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_header_dtls_config_proto_rawDesc = "" +
"\n" +
"5transport/internet/finalmask/header/dtls/config.proto\x12-xray.transport.internet.finalmask.header.dtls\"\b\n" +
"\x06ConfigB\xb8\x01\n" +
"1com.xray.transport.internet.finalmask.header.dtlsP\x01ZQgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dtls\xaa\x02-Xray.Transport.Internet.Finalmask.Header.Dtlsb\x06proto3"
var (
file_transport_internet_finalmask_header_dtls_config_proto_rawDescOnce sync.Once
file_transport_internet_finalmask_header_dtls_config_proto_rawDescData []byte
)
func file_transport_internet_finalmask_header_dtls_config_proto_rawDescGZIP() []byte {
file_transport_internet_finalmask_header_dtls_config_proto_rawDescOnce.Do(func() {
file_transport_internet_finalmask_header_dtls_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_dtls_config_proto_rawDesc), len(file_transport_internet_finalmask_header_dtls_config_proto_rawDesc)))
})
return file_transport_internet_finalmask_header_dtls_config_proto_rawDescData
}
var file_transport_internet_finalmask_header_dtls_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_header_dtls_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.finalmask.header.dtls.Config
}
var file_transport_internet_finalmask_header_dtls_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_transport_internet_finalmask_header_dtls_config_proto_init() }
func file_transport_internet_finalmask_header_dtls_config_proto_init() {
if File_transport_internet_finalmask_header_dtls_config_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_dtls_config_proto_rawDesc), len(file_transport_internet_finalmask_header_dtls_config_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_transport_internet_finalmask_header_dtls_config_proto_goTypes,
DependencyIndexes: file_transport_internet_finalmask_header_dtls_config_proto_depIdxs,
MessageInfos: file_transport_internet_finalmask_header_dtls_config_proto_msgTypes,
}.Build()
File_transport_internet_finalmask_header_dtls_config_proto = out.File
file_transport_internet_finalmask_header_dtls_config_proto_goTypes = nil
file_transport_internet_finalmask_header_dtls_config_proto_depIdxs = nil
}

View File

@@ -0,0 +1,9 @@
syntax = "proto3";
package xray.transport.internet.finalmask.header.dtls;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Dtls";
option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dtls";
option java_package = "com.xray.transport.internet.finalmask.header.dtls";
option java_multiple_files = true;
message Config {}

View File

@@ -0,0 +1,178 @@
package dtls
import (
"io"
"net"
sync "sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/dice"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
)
type dtls struct {
epoch uint16
length uint16
sequence uint32
}
func (*dtls) Size() int32 {
return 1 + 2 + 2 + 6 + 2
}
func (h *dtls) Serialize(b []byte) {
b[0] = 23
b[1] = 254
b[2] = 253
b[3] = byte(h.epoch >> 8)
b[4] = byte(h.epoch)
b[5] = 0
b[6] = 0
b[7] = byte(h.sequence >> 24)
b[8] = byte(h.sequence >> 16)
b[9] = byte(h.sequence >> 8)
b[10] = byte(h.sequence)
h.sequence++
b[11] = byte(h.length >> 8)
b[12] = byte(h.length)
h.length += 17
if h.length > 100 {
h.length -= 50
}
}
type dtlsConn struct {
first bool
leaveSize int32
conn net.PacketConn
header *dtls
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
conn := &dtlsConn{
first: first,
leaveSize: leaveSize,
conn: raw,
header: &dtls{
epoch: dice.RollUint16(),
sequence: 0,
length: 17,
},
}
if first {
conn.readBuf = make([]byte, 8192)
conn.writeBuf = make([]byte, 8192)
}
return conn, nil
}
func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *dtlsConn) Size() int32 {
return c.header.Size()
}
func (c *dtlsConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if c.first {
c.readMutex.Lock()
n, addr, err = c.conn.ReadFrom(c.readBuf)
if err != nil {
c.readMutex.Unlock()
return n, addr, err
}
if n < int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
if len(p) < n-int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, c.readBuf[c.Size():n])
c.readMutex.Unlock()
return n - int(c.Size()), addr, err
}
n, addr, err = c.conn.ReadFrom(p)
if err != nil {
return n, addr, err
}
if n < int(c.Size()) {
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, p[c.Size():n])
return n - int(c.Size()), addr, err
}
func (c *dtlsConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if c.first {
if c.leaveSize+c.Size()+int32(len(p)) > 8192 {
return 0, errors.New("too many masks")
}
c.writeMutex.Lock()
n = copy(c.writeBuf[c.leaveSize+c.Size():], p)
n += int(c.leaveSize) + int(c.Size())
c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()])
nn, err := c.conn.WriteTo(c.writeBuf[:n], addr)
if err != nil {
c.writeMutex.Unlock()
return 0, err
}
if nn != n {
c.writeMutex.Unlock()
return 0, errors.New("nn != n")
}
c.writeMutex.Unlock()
return len(p), nil
}
c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()])
return c.conn.WriteTo(p, addr)
}
func (c *dtlsConn) Close() error {
return c.conn.Close()
}
func (c *dtlsConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *dtlsConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *dtlsConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *dtlsConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

View File

@@ -0,0 +1,16 @@
package srtp
import (
"net"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnServer(c, raw, first, leaveSize)
}

View File

@@ -0,0 +1,114 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v3.21.12
// source: transport/internet/finalmask/header/srtp/config.proto
package srtp
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Config) Reset() {
*x = Config{}
mi := &file_transport_internet_finalmask_header_srtp_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Config) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Config) ProtoMessage() {}
func (x *Config) ProtoReflect() protoreflect.Message {
mi := &file_transport_internet_finalmask_header_srtp_config_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_header_srtp_config_proto_rawDescGZIP(), []int{0}
}
var File_transport_internet_finalmask_header_srtp_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_header_srtp_config_proto_rawDesc = "" +
"\n" +
"5transport/internet/finalmask/header/srtp/config.proto\x12-xray.transport.internet.finalmask.header.srtp\"\b\n" +
"\x06ConfigB\xb8\x01\n" +
"1com.xray.transport.internet.finalmask.header.srtpP\x01ZQgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/srtp\xaa\x02-Xray.Transport.Internet.Finalmask.Header.Srtpb\x06proto3"
var (
file_transport_internet_finalmask_header_srtp_config_proto_rawDescOnce sync.Once
file_transport_internet_finalmask_header_srtp_config_proto_rawDescData []byte
)
func file_transport_internet_finalmask_header_srtp_config_proto_rawDescGZIP() []byte {
file_transport_internet_finalmask_header_srtp_config_proto_rawDescOnce.Do(func() {
file_transport_internet_finalmask_header_srtp_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_srtp_config_proto_rawDesc), len(file_transport_internet_finalmask_header_srtp_config_proto_rawDesc)))
})
return file_transport_internet_finalmask_header_srtp_config_proto_rawDescData
}
var file_transport_internet_finalmask_header_srtp_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_header_srtp_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.finalmask.header.srtp.Config
}
var file_transport_internet_finalmask_header_srtp_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_transport_internet_finalmask_header_srtp_config_proto_init() }
func file_transport_internet_finalmask_header_srtp_config_proto_init() {
if File_transport_internet_finalmask_header_srtp_config_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_srtp_config_proto_rawDesc), len(file_transport_internet_finalmask_header_srtp_config_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_transport_internet_finalmask_header_srtp_config_proto_goTypes,
DependencyIndexes: file_transport_internet_finalmask_header_srtp_config_proto_depIdxs,
MessageInfos: file_transport_internet_finalmask_header_srtp_config_proto_msgTypes,
}.Build()
File_transport_internet_finalmask_header_srtp_config_proto = out.File
file_transport_internet_finalmask_header_srtp_config_proto_goTypes = nil
file_transport_internet_finalmask_header_srtp_config_proto_depIdxs = nil
}

View File

@@ -0,0 +1,9 @@
syntax = "proto3";
package xray.transport.internet.finalmask.header.srtp;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Srtp";
option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/srtp";
option java_package = "com.xray.transport.internet.finalmask.header.srtp";
option java_multiple_files = true;
message Config {}

View File

@@ -0,0 +1,162 @@
package srtp
import (
"encoding/binary"
"io"
"net"
sync "sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/dice"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
)
type srtp struct {
header uint16
number uint16
}
func (*srtp) Size() int32 {
return 4
}
func (h *srtp) Serialize(b []byte) {
h.number++
binary.BigEndian.PutUint16(b, h.header)
binary.BigEndian.PutUint16(b[2:], h.number)
}
type srtpConn struct {
first bool
leaveSize int32
conn net.PacketConn
header *srtp
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
conn := &srtpConn{
first: first,
leaveSize: leaveSize,
conn: raw,
header: &srtp{
header: 0xB5E8,
number: dice.RollUint16(),
},
}
if first {
conn.readBuf = make([]byte, 8192)
conn.writeBuf = make([]byte, 8192)
}
return conn, nil
}
func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *srtpConn) Size() int32 {
return c.header.Size()
}
func (c *srtpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if c.first {
c.readMutex.Lock()
n, addr, err = c.conn.ReadFrom(c.readBuf)
if err != nil {
c.readMutex.Unlock()
return n, addr, err
}
if n < int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
if len(p) < n-int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, c.readBuf[c.Size():n])
c.readMutex.Unlock()
return n - int(c.Size()), addr, err
}
n, addr, err = c.conn.ReadFrom(p)
if err != nil {
return n, addr, err
}
if n < int(c.Size()) {
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, p[c.Size():n])
return n - int(c.Size()), addr, err
}
func (c *srtpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if c.first {
if c.leaveSize+c.Size()+int32(len(p)) > 8192 {
return 0, errors.New("too many masks")
}
c.writeMutex.Lock()
n = copy(c.writeBuf[c.leaveSize+c.Size():], p)
n += int(c.leaveSize) + int(c.Size())
c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()])
nn, err := c.conn.WriteTo(c.writeBuf[:n], addr)
if err != nil {
c.writeMutex.Unlock()
return 0, err
}
if nn != n {
c.writeMutex.Unlock()
return 0, errors.New("nn != n")
}
c.writeMutex.Unlock()
return len(p), nil
}
c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()])
return c.conn.WriteTo(p, addr)
}
func (c *srtpConn) Close() error {
return c.conn.Close()
}
func (c *srtpConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *srtpConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *srtpConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *srtpConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

View File

@@ -0,0 +1,16 @@
package utp
import (
"net"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnServer(c, raw, first, leaveSize)
}

View File

@@ -0,0 +1,114 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v3.21.12
// source: transport/internet/finalmask/header/utp/config.proto
package utp
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Config) Reset() {
*x = Config{}
mi := &file_transport_internet_finalmask_header_utp_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Config) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Config) ProtoMessage() {}
func (x *Config) ProtoReflect() protoreflect.Message {
mi := &file_transport_internet_finalmask_header_utp_config_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_header_utp_config_proto_rawDescGZIP(), []int{0}
}
var File_transport_internet_finalmask_header_utp_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_header_utp_config_proto_rawDesc = "" +
"\n" +
"4transport/internet/finalmask/header/utp/config.proto\x12,xray.transport.internet.finalmask.header.utp\"\b\n" +
"\x06ConfigB\xb5\x01\n" +
"0com.xray.transport.internet.finalmask.header.utpP\x01ZPgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/utp\xaa\x02,Xray.Transport.Internet.Finalmask.Header.Utpb\x06proto3"
var (
file_transport_internet_finalmask_header_utp_config_proto_rawDescOnce sync.Once
file_transport_internet_finalmask_header_utp_config_proto_rawDescData []byte
)
func file_transport_internet_finalmask_header_utp_config_proto_rawDescGZIP() []byte {
file_transport_internet_finalmask_header_utp_config_proto_rawDescOnce.Do(func() {
file_transport_internet_finalmask_header_utp_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_utp_config_proto_rawDesc), len(file_transport_internet_finalmask_header_utp_config_proto_rawDesc)))
})
return file_transport_internet_finalmask_header_utp_config_proto_rawDescData
}
var file_transport_internet_finalmask_header_utp_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_header_utp_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.finalmask.header.utp.Config
}
var file_transport_internet_finalmask_header_utp_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_transport_internet_finalmask_header_utp_config_proto_init() }
func file_transport_internet_finalmask_header_utp_config_proto_init() {
if File_transport_internet_finalmask_header_utp_config_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_utp_config_proto_rawDesc), len(file_transport_internet_finalmask_header_utp_config_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_transport_internet_finalmask_header_utp_config_proto_goTypes,
DependencyIndexes: file_transport_internet_finalmask_header_utp_config_proto_depIdxs,
MessageInfos: file_transport_internet_finalmask_header_utp_config_proto_msgTypes,
}.Build()
File_transport_internet_finalmask_header_utp_config_proto = out.File
file_transport_internet_finalmask_header_utp_config_proto_goTypes = nil
file_transport_internet_finalmask_header_utp_config_proto_depIdxs = nil
}

View File

@@ -0,0 +1,9 @@
syntax = "proto3";
package xray.transport.internet.finalmask.header.utp;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Utp";
option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/utp";
option java_package = "com.xray.transport.internet.finalmask.header.utp";
option java_multiple_files = true;
message Config {}

View File

@@ -0,0 +1,164 @@
package utp
import (
"encoding/binary"
"io"
"net"
sync "sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/dice"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
)
type utp struct {
header byte
extension byte
connectionID uint16
}
func (*utp) Size() int32 {
return 4
}
func (h *utp) Serialize(b []byte) {
binary.BigEndian.PutUint16(b, h.connectionID)
b[2] = h.header
b[3] = h.extension
}
type utpConn struct {
first bool
leaveSize int32
conn net.PacketConn
header *utp
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
conn := &utpConn{
first: first,
leaveSize: leaveSize,
conn: raw,
header: &utp{
header: 1,
extension: 0,
connectionID: dice.RollUint16(),
},
}
if first {
conn.readBuf = make([]byte, 8192)
conn.writeBuf = make([]byte, 8192)
}
return conn, nil
}
func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *utpConn) Size() int32 {
return c.header.Size()
}
func (c *utpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if c.first {
c.readMutex.Lock()
n, addr, err = c.conn.ReadFrom(c.readBuf)
if err != nil {
c.readMutex.Unlock()
return n, addr, err
}
if n < int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
if len(p) < n-int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, c.readBuf[c.Size():n])
c.readMutex.Unlock()
return n - int(c.Size()), addr, err
}
n, addr, err = c.conn.ReadFrom(p)
if err != nil {
return n, addr, err
}
if n < int(c.Size()) {
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, p[c.Size():n])
return n - int(c.Size()), addr, err
}
func (c *utpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if c.first {
if c.leaveSize+c.Size()+int32(len(p)) > 8192 {
return 0, errors.New("too many masks")
}
c.writeMutex.Lock()
n = copy(c.writeBuf[c.leaveSize+c.Size():], p)
n += int(c.leaveSize) + int(c.Size())
c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()])
nn, err := c.conn.WriteTo(c.writeBuf[:n], addr)
if err != nil {
c.writeMutex.Unlock()
return 0, err
}
if nn != n {
c.writeMutex.Unlock()
return 0, errors.New("nn != n")
}
c.writeMutex.Unlock()
return len(p), nil
}
c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()])
return c.conn.WriteTo(p, addr)
}
func (c *utpConn) Close() error {
return c.conn.Close()
}
func (c *utpConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *utpConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *utpConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *utpConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

View File

@@ -0,0 +1,16 @@
package wechat
import (
"net"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnServer(c, raw, first, leaveSize)
}

View File

@@ -0,0 +1,114 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v3.21.12
// source: transport/internet/finalmask/header/wechat/config.proto
package wechat
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Config) Reset() {
*x = Config{}
mi := &file_transport_internet_finalmask_header_wechat_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Config) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Config) ProtoMessage() {}
func (x *Config) ProtoReflect() protoreflect.Message {
mi := &file_transport_internet_finalmask_header_wechat_config_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_header_wechat_config_proto_rawDescGZIP(), []int{0}
}
var File_transport_internet_finalmask_header_wechat_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_header_wechat_config_proto_rawDesc = "" +
"\n" +
"7transport/internet/finalmask/header/wechat/config.proto\x12/xray.transport.internet.finalmask.header.wechat\"\b\n" +
"\x06ConfigB\xbe\x01\n" +
"3com.xray.transport.internet.finalmask.header.wechatP\x01ZSgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wechat\xaa\x02/Xray.Transport.Internet.Finalmask.Header.Wechatb\x06proto3"
var (
file_transport_internet_finalmask_header_wechat_config_proto_rawDescOnce sync.Once
file_transport_internet_finalmask_header_wechat_config_proto_rawDescData []byte
)
func file_transport_internet_finalmask_header_wechat_config_proto_rawDescGZIP() []byte {
file_transport_internet_finalmask_header_wechat_config_proto_rawDescOnce.Do(func() {
file_transport_internet_finalmask_header_wechat_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_wechat_config_proto_rawDesc), len(file_transport_internet_finalmask_header_wechat_config_proto_rawDesc)))
})
return file_transport_internet_finalmask_header_wechat_config_proto_rawDescData
}
var file_transport_internet_finalmask_header_wechat_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_header_wechat_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.finalmask.header.wechat.Config
}
var file_transport_internet_finalmask_header_wechat_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_transport_internet_finalmask_header_wechat_config_proto_init() }
func file_transport_internet_finalmask_header_wechat_config_proto_init() {
if File_transport_internet_finalmask_header_wechat_config_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_wechat_config_proto_rawDesc), len(file_transport_internet_finalmask_header_wechat_config_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_transport_internet_finalmask_header_wechat_config_proto_goTypes,
DependencyIndexes: file_transport_internet_finalmask_header_wechat_config_proto_depIdxs,
MessageInfos: file_transport_internet_finalmask_header_wechat_config_proto_msgTypes,
}.Build()
File_transport_internet_finalmask_header_wechat_config_proto = out.File
file_transport_internet_finalmask_header_wechat_config_proto_goTypes = nil
file_transport_internet_finalmask_header_wechat_config_proto_depIdxs = nil
}

View File

@@ -0,0 +1,9 @@
syntax = "proto3";
package xray.transport.internet.finalmask.header.wechat;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Wechat";
option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wechat";
option java_package = "com.xray.transport.internet.finalmask.header.wechat";
option java_multiple_files = true;
message Config {}

View File

@@ -0,0 +1,168 @@
package wechat
import (
"encoding/binary"
"io"
"net"
sync "sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/dice"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
)
type wechat struct {
sn uint32
}
func (*wechat) Size() int32 {
return 13
}
func (h *wechat) Serialize(b []byte) {
h.sn++
b[0] = 0xa1
b[1] = 0x08
binary.BigEndian.PutUint32(b[2:], h.sn)
b[6] = 0x00
b[7] = 0x10
b[8] = 0x11
b[9] = 0x18
b[10] = 0x30
b[11] = 0x22
b[12] = 0x30
}
type wechatConn struct {
first bool
leaveSize int32
conn net.PacketConn
header *wechat
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
conn := &wechatConn{
first: first,
leaveSize: leaveSize,
conn: raw,
header: &wechat{
sn: uint32(dice.RollUint16()),
},
}
if first {
conn.readBuf = make([]byte, 8192)
conn.writeBuf = make([]byte, 8192)
}
return conn, nil
}
func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *wechatConn) Size() int32 {
return c.header.Size()
}
func (c *wechatConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if c.first {
c.readMutex.Lock()
n, addr, err = c.conn.ReadFrom(c.readBuf)
if err != nil {
c.readMutex.Unlock()
return n, addr, err
}
if n < int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
if len(p) < n-int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, c.readBuf[c.Size():n])
c.readMutex.Unlock()
return n - int(c.Size()), addr, err
}
n, addr, err = c.conn.ReadFrom(p)
if err != nil {
return n, addr, err
}
if n < int(c.Size()) {
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, p[c.Size():n])
return n - int(c.Size()), addr, err
}
func (c *wechatConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if c.first {
if c.leaveSize+c.Size()+int32(len(p)) > 8192 {
return 0, errors.New("too many masks")
}
c.writeMutex.Lock()
n = copy(c.writeBuf[c.leaveSize+c.Size():], p)
n += int(c.leaveSize) + int(c.Size())
c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()])
nn, err := c.conn.WriteTo(c.writeBuf[:n], addr)
if err != nil {
c.writeMutex.Unlock()
return 0, err
}
if nn != n {
c.writeMutex.Unlock()
return 0, errors.New("nn != n")
}
c.writeMutex.Unlock()
return len(p), nil
}
c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()])
return c.conn.WriteTo(p, addr)
}
func (c *wechatConn) Close() error {
return c.conn.Close()
}
func (c *wechatConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *wechatConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *wechatConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *wechatConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

View File

@@ -0,0 +1,16 @@
package wireguard
import (
"net"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnServer(c, raw, first, leaveSize)
}

View File

@@ -0,0 +1,114 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v3.21.12
// source: transport/internet/finalmask/header/wireguard/config.proto
package wireguard
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Config) Reset() {
*x = Config{}
mi := &file_transport_internet_finalmask_header_wireguard_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Config) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Config) ProtoMessage() {}
func (x *Config) ProtoReflect() protoreflect.Message {
mi := &file_transport_internet_finalmask_header_wireguard_config_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_header_wireguard_config_proto_rawDescGZIP(), []int{0}
}
var File_transport_internet_finalmask_header_wireguard_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_header_wireguard_config_proto_rawDesc = "" +
"\n" +
":transport/internet/finalmask/header/wireguard/config.proto\x122xray.transport.internet.finalmask.header.wireguard\"\b\n" +
"\x06ConfigB\xc7\x01\n" +
"6com.xray.transport.internet.finalmask.header.wireguardP\x01ZVgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wireguard\xaa\x022Xray.Transport.Internet.Finalmask.Header.Wireguardb\x06proto3"
var (
file_transport_internet_finalmask_header_wireguard_config_proto_rawDescOnce sync.Once
file_transport_internet_finalmask_header_wireguard_config_proto_rawDescData []byte
)
func file_transport_internet_finalmask_header_wireguard_config_proto_rawDescGZIP() []byte {
file_transport_internet_finalmask_header_wireguard_config_proto_rawDescOnce.Do(func() {
file_transport_internet_finalmask_header_wireguard_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_wireguard_config_proto_rawDesc), len(file_transport_internet_finalmask_header_wireguard_config_proto_rawDesc)))
})
return file_transport_internet_finalmask_header_wireguard_config_proto_rawDescData
}
var file_transport_internet_finalmask_header_wireguard_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_header_wireguard_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.finalmask.header.wireguard.Config
}
var file_transport_internet_finalmask_header_wireguard_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_transport_internet_finalmask_header_wireguard_config_proto_init() }
func file_transport_internet_finalmask_header_wireguard_config_proto_init() {
if File_transport_internet_finalmask_header_wireguard_config_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_wireguard_config_proto_rawDesc), len(file_transport_internet_finalmask_header_wireguard_config_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_transport_internet_finalmask_header_wireguard_config_proto_goTypes,
DependencyIndexes: file_transport_internet_finalmask_header_wireguard_config_proto_depIdxs,
MessageInfos: file_transport_internet_finalmask_header_wireguard_config_proto_msgTypes,
}.Build()
File_transport_internet_finalmask_header_wireguard_config_proto = out.File
file_transport_internet_finalmask_header_wireguard_config_proto_goTypes = nil
file_transport_internet_finalmask_header_wireguard_config_proto_depIdxs = nil
}

View File

@@ -0,0 +1,9 @@
syntax = "proto3";
package xray.transport.internet.finalmask.header.wireguard;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Wireguard";
option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wireguard";
option java_package = "com.xray.transport.internet.finalmask.header.wireguard";
option java_multiple_files = true;
message Config {}

View File

@@ -0,0 +1,155 @@
package wireguard
import (
"io"
"net"
sync "sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
)
type wireguare struct{}
func (*wireguare) Size() int32 {
return 4
}
func (h *wireguare) Serialize(b []byte) {
b[0] = 0x04
b[1] = 0x00
b[2] = 0x00
b[3] = 0x00
}
type wireguareConn struct {
first bool
leaveSize int32
conn net.PacketConn
header *wireguare
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
conn := &wireguareConn{
first: first,
leaveSize: leaveSize,
conn: raw,
header: &wireguare{},
}
if first {
conn.readBuf = make([]byte, 8192)
conn.writeBuf = make([]byte, 8192)
}
return conn, nil
}
func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *wireguareConn) Size() int32 {
return c.header.Size()
}
func (c *wireguareConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if c.first {
c.readMutex.Lock()
n, addr, err = c.conn.ReadFrom(c.readBuf)
if err != nil {
c.readMutex.Unlock()
return n, addr, err
}
if n < int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
if len(p) < n-int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, c.readBuf[c.Size():n])
c.readMutex.Unlock()
return n - int(c.Size()), addr, err
}
n, addr, err = c.conn.ReadFrom(p)
if err != nil {
return n, addr, err
}
if n < int(c.Size()) {
return 0, addr, errors.New("header").Base(io.ErrShortBuffer)
}
copy(p, p[c.Size():n])
return n - int(c.Size()), addr, err
}
func (c *wireguareConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if c.first {
if c.leaveSize+c.Size()+int32(len(p)) > 8192 {
return 0, errors.New("too many masks")
}
c.writeMutex.Lock()
n = copy(c.writeBuf[c.leaveSize+c.Size():], p)
n += int(c.leaveSize) + int(c.Size())
c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()])
nn, err := c.conn.WriteTo(c.writeBuf[:n], addr)
if err != nil {
c.writeMutex.Unlock()
return 0, err
}
if nn != n {
c.writeMutex.Unlock()
return 0, errors.New("nn != n")
}
c.writeMutex.Unlock()
return len(p), nil
}
c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()])
return c.conn.WriteTo(p, addr)
}
func (c *wireguareConn) Close() error {
return c.conn.Close()
}
func (c *wireguareConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *wireguareConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *wireguareConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *wireguareConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

View File

@@ -0,0 +1,70 @@
package aes128gcm_test
import (
"crypto/rand"
"crypto/sha256"
"testing"
"github.com/amnezia-vpn/amnezia-xray-core/common/crypto"
"github.com/stretchr/testify/assert"
)
func TestAes128GcmSealInPlace(t *testing.T) {
hashedPsk := sha256.Sum256([]byte("psk"))
aead := crypto.NewAesGcm(hashedPsk[:16])
text := []byte("0123456789012")
buf := make([]byte, 8192)
nonceSize := aead.NonceSize()
nonce := buf[:nonceSize]
rand.Read(nonce)
copy(buf[nonceSize:], text)
plaintext := buf[nonceSize : nonceSize+len(text)]
sealed := aead.Seal(nil, nonce, plaintext, nil)
_ = aead.Seal(plaintext[:0], nonce, plaintext, nil)
assert.Equal(t, sealed, buf[nonceSize:nonceSize+aead.Overhead()+len(text)])
}
func encrypted(plain []byte) ([]byte, []byte) {
hashedPsk := sha256.Sum256([]byte("psk"))
aead := crypto.NewAesGcm(hashedPsk[:16])
nonce := make([]byte, 12)
rand.Read(nonce)
return nonce, aead.Seal(nil, nonce, plain, nil)
}
func TestAes128GcmOpenInPlace(t *testing.T) {
a, b := encrypted([]byte("0123456789012"))
buf := make([]byte, 8192)
copy(buf, a)
copy(buf[len(a):], b)
hashedPsk := sha256.Sum256([]byte("psk"))
aead := crypto.NewAesGcm(hashedPsk[:16])
nonceSize := aead.NonceSize()
nonce := buf[:nonceSize]
ciphertext := buf[nonceSize : nonceSize+len(b)]
opened, _ := aead.Open(nil, nonce, ciphertext, nil)
_, _ = aead.Open(ciphertext[:0], nonce, ciphertext, nil)
assert.Equal(t, opened, ciphertext[:len(ciphertext)-aead.Overhead()])
}
func TestAes128GcmBounce(t *testing.T) {
hashedPsk := sha256.Sum256([]byte("psk"))
aead := crypto.NewAesGcm(hashedPsk[:16])
buf := make([]byte, aead.NonceSize()+aead.Overhead())
for i := 0; i < 1000; i++ {
_, _ = rand.Read(buf)
_, err := aead.Open(buf[aead.NonceSize():aead.NonceSize()], buf[:aead.NonceSize()], buf[aead.NonceSize():], nil)
assert.NotEqual(t, err, nil)
}
}

View File

@@ -0,0 +1,16 @@
package aes128gcm
import (
"net"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnServer(c, raw, first, leaveSize)
}

View File

@@ -0,0 +1,123 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v3.21.12
// source: transport/internet/finalmask/mkcp/aes128gcm/config.proto
package aes128gcm
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Config) Reset() {
*x = Config{}
mi := &file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Config) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Config) ProtoMessage() {}
func (x *Config) ProtoReflect() protoreflect.Message {
mi := &file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescGZIP(), []int{0}
}
func (x *Config) GetPassword() string {
if x != nil {
return x.Password
}
return ""
}
var File_transport_internet_finalmask_mkcp_aes128gcm_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDesc = "" +
"\n" +
"8transport/internet/finalmask/mkcp/aes128gcm/config.proto\x120xray.transport.internet.finalmask.mkcp.aes128gcm\"$\n" +
"\x06Config\x12\x1a\n" +
"\bpassword\x18\x01 \x01(\tR\bpasswordB\xc1\x01\n" +
"4com.xray.transport.internet.finalmask.mkcp.aes128gcmP\x01ZTgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/aes128gcm\xaa\x020Xray.Transport.Internet.Finalmask.Mkcp.Aes128Gcmb\x06proto3"
var (
file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescOnce sync.Once
file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescData []byte
)
func file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescGZIP() []byte {
file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescOnce.Do(func() {
file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDesc), len(file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDesc)))
})
return file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescData
}
var file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.finalmask.mkcp.aes128gcm.Config
}
var file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_init() }
func file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_init() {
if File_transport_internet_finalmask_mkcp_aes128gcm_config_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDesc), len(file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_goTypes,
DependencyIndexes: file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_depIdxs,
MessageInfos: file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_msgTypes,
}.Build()
File_transport_internet_finalmask_mkcp_aes128gcm_config_proto = out.File
file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_goTypes = nil
file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_depIdxs = nil
}

View File

@@ -0,0 +1,11 @@
syntax = "proto3";
package xray.transport.internet.finalmask.mkcp.aes128gcm;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Mkcp.Aes128Gcm";
option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/aes128gcm";
option java_package = "com.xray.transport.internet.finalmask.mkcp.aes128gcm";
option java_multiple_files = true;
message Config {
string password = 1;
}

View File

@@ -0,0 +1,174 @@
package aes128gcm
import (
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"io"
"net"
sync "sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common"
"github.com/amnezia-vpn/amnezia-xray-core/common/crypto"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
)
type aes128gcmConn struct {
first bool
leaveSize int32
conn net.PacketConn
aead cipher.AEAD
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
hashedPsk := sha256.Sum256([]byte(c.Password))
conn := &aes128gcmConn{
first: first,
leaveSize: leaveSize,
conn: raw,
aead: crypto.NewAesGcm(hashedPsk[:16]),
}
if first {
conn.readBuf = make([]byte, 8192)
conn.writeBuf = make([]byte, 8192)
}
return conn, nil
}
func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *aes128gcmConn) Size() int32 {
return int32(c.aead.NonceSize()) + int32(c.aead.Overhead())
}
func (c *aes128gcmConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if c.first {
c.readMutex.Lock()
n, addr, err = c.conn.ReadFrom(c.readBuf)
if err != nil {
c.readMutex.Unlock()
return n, addr, err
}
if n < int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("aead").Base(io.ErrShortBuffer)
}
if len(p) < n-int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("aead").Base(io.ErrShortBuffer)
}
nonceSize := c.aead.NonceSize()
nonce := c.readBuf[:nonceSize]
ciphertext := c.readBuf[nonceSize:n]
_, err = c.aead.Open(p[:0], nonce, ciphertext, nil)
if err != nil {
c.readMutex.Unlock()
return 0, addr, errors.New("aead open").Base(err)
}
c.readMutex.Unlock()
return n - int(c.Size()), addr, nil
}
n, addr, err = c.conn.ReadFrom(p)
if err != nil {
return n, addr, err
}
if n < int(c.Size()) {
return 0, addr, errors.New("aead").Base(io.ErrShortBuffer)
}
nonceSize := c.aead.NonceSize()
nonce := p[:nonceSize]
ciphertext := p[nonceSize:n]
_, err = c.aead.Open(ciphertext[:0], nonce, ciphertext, nil)
if err != nil {
return 0, addr, errors.New("aead open").Base(err)
}
copy(p, p[nonceSize:n-c.aead.Overhead()])
return n - int(c.Size()), addr, nil
}
func (c *aes128gcmConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if c.first {
if c.leaveSize+c.Size()+int32(len(p)) > 8192 {
return 0, errors.New("too many masks")
}
c.writeMutex.Lock()
n = copy(c.writeBuf[c.leaveSize+int32(c.aead.NonceSize()):], p)
// n = copy(c.writeBuf[c.leaveSize+c.Size():], p)
n += int(c.leaveSize) + int(c.Size())
nonceSize := c.aead.NonceSize()
nonce := c.writeBuf[c.leaveSize : c.leaveSize+int32(nonceSize)]
common.Must2(rand.Read(nonce))
// copy(c.writeBuf[c.leaveSize+int32(nonceSize):], c.writeBuf[c.leaveSize+c.Size():n])
plaintext := c.writeBuf[c.leaveSize+int32(nonceSize) : n-c.aead.Overhead()]
_ = c.aead.Seal(plaintext[:0], nonce, plaintext, nil)
nn, err := c.conn.WriteTo(c.writeBuf[:n], addr)
if err != nil {
c.writeMutex.Unlock()
return 0, err
}
if nn != n {
c.writeMutex.Unlock()
return 0, errors.New("nn != n")
}
c.writeMutex.Unlock()
return len(p), nil
}
nonceSize := c.aead.NonceSize()
nonce := p[c.leaveSize : c.leaveSize+int32(nonceSize)]
common.Must2(rand.Read(nonce))
copy(p[c.leaveSize+int32(nonceSize):], p[c.leaveSize+c.Size():])
plaintext := p[c.leaveSize+int32(nonceSize) : len(p)-c.aead.Overhead()]
_ = c.aead.Seal(plaintext[:0], nonce, plaintext, nil)
return c.conn.WriteTo(p, addr)
}
func (c *aes128gcmConn) Close() error {
return c.conn.Close()
}
func (c *aes128gcmConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *aes128gcmConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *aes128gcmConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *aes128gcmConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

View File

@@ -0,0 +1,16 @@
package original
import (
"net"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnServer(c, raw, first, leaveSize)
}

View File

@@ -0,0 +1,114 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v3.21.12
// source: transport/internet/finalmask/mkcp/original/config.proto
package original
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Config) Reset() {
*x = Config{}
mi := &file_transport_internet_finalmask_mkcp_original_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Config) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Config) ProtoMessage() {}
func (x *Config) ProtoReflect() protoreflect.Message {
mi := &file_transport_internet_finalmask_mkcp_original_config_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_mkcp_original_config_proto_rawDescGZIP(), []int{0}
}
var File_transport_internet_finalmask_mkcp_original_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_mkcp_original_config_proto_rawDesc = "" +
"\n" +
"7transport/internet/finalmask/mkcp/original/config.proto\x12/xray.transport.internet.finalmask.mkcp.original\"\b\n" +
"\x06ConfigB\xbe\x01\n" +
"3com.xray.transport.internet.finalmask.mkcp.originalP\x01ZSgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/original\xaa\x02/Xray.Transport.Internet.Finalmask.Mkcp.Originalb\x06proto3"
var (
file_transport_internet_finalmask_mkcp_original_config_proto_rawDescOnce sync.Once
file_transport_internet_finalmask_mkcp_original_config_proto_rawDescData []byte
)
func file_transport_internet_finalmask_mkcp_original_config_proto_rawDescGZIP() []byte {
file_transport_internet_finalmask_mkcp_original_config_proto_rawDescOnce.Do(func() {
file_transport_internet_finalmask_mkcp_original_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_mkcp_original_config_proto_rawDesc), len(file_transport_internet_finalmask_mkcp_original_config_proto_rawDesc)))
})
return file_transport_internet_finalmask_mkcp_original_config_proto_rawDescData
}
var file_transport_internet_finalmask_mkcp_original_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_mkcp_original_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.finalmask.mkcp.original.Config
}
var file_transport_internet_finalmask_mkcp_original_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_transport_internet_finalmask_mkcp_original_config_proto_init() }
func file_transport_internet_finalmask_mkcp_original_config_proto_init() {
if File_transport_internet_finalmask_mkcp_original_config_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_mkcp_original_config_proto_rawDesc), len(file_transport_internet_finalmask_mkcp_original_config_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_transport_internet_finalmask_mkcp_original_config_proto_goTypes,
DependencyIndexes: file_transport_internet_finalmask_mkcp_original_config_proto_depIdxs,
MessageInfos: file_transport_internet_finalmask_mkcp_original_config_proto_msgTypes,
}.Build()
File_transport_internet_finalmask_mkcp_original_config_proto = out.File
file_transport_internet_finalmask_mkcp_original_config_proto_goTypes = nil
file_transport_internet_finalmask_mkcp_original_config_proto_depIdxs = nil
}

View File

@@ -0,0 +1,9 @@
syntax = "proto3";
package xray.transport.internet.finalmask.mkcp.original;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Mkcp.Original";
option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/original";
option java_package = "com.xray.transport.internet.finalmask.mkcp.original";
option java_multiple_files = true;
message Config {}

View File

@@ -0,0 +1,225 @@
package original
import (
"crypto/cipher"
"encoding/binary"
"hash/fnv"
"io"
"net"
sync "sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
)
type simple struct{}
func NewSimple() *simple {
return &simple{}
}
func (*simple) NonceSize() int {
return 0
}
func (*simple) Overhead() int {
return 6
}
func (a *simple) Seal(dst, nonce, plain, extra []byte) []byte {
dst = append(dst, 0, 0, 0, 0, 0, 0)
binary.BigEndian.PutUint16(dst[4:], uint16(len(plain)))
dst = append(dst, plain...)
fnvHash := fnv.New32a()
common.Must2(fnvHash.Write(dst[4:]))
fnvHash.Sum(dst[:0])
dstLen := len(dst)
xtra := 4 - dstLen%4
if xtra != 4 {
dst = append(dst, make([]byte, xtra)...)
}
xorfwd(dst)
if xtra != 4 {
dst = dst[:dstLen]
}
return dst
}
func (a *simple) Open(dst, nonce, cipherText, extra []byte) ([]byte, error) {
dst = append(dst, cipherText...)
dstLen := len(dst)
xtra := 4 - dstLen%4
if xtra != 4 {
dst = append(dst, make([]byte, xtra)...)
}
xorbkd(dst)
if xtra != 4 {
dst = dst[:dstLen]
}
fnvHash := fnv.New32a()
common.Must2(fnvHash.Write(dst[4:]))
if binary.BigEndian.Uint32(dst[:4]) != fnvHash.Sum32() {
return nil, errors.New("invalid auth")
}
length := binary.BigEndian.Uint16(dst[4:6])
if len(dst)-6 != int(length) {
return nil, errors.New("invalid auth")
}
return dst[6:], nil
}
type simpleConn struct {
first bool
leaveSize int32
conn net.PacketConn
aead cipher.AEAD
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
conn := &simpleConn{
first: first,
leaveSize: leaveSize,
conn: raw,
aead: &simple{},
}
if first {
conn.readBuf = make([]byte, 8192)
conn.writeBuf = make([]byte, 8192)
}
return conn, nil
}
func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *simpleConn) Size() int32 {
return int32(c.aead.NonceSize()) + int32(c.aead.Overhead())
}
func (c *simpleConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if c.first {
c.readMutex.Lock()
n, addr, err = c.conn.ReadFrom(c.readBuf)
if err != nil {
c.readMutex.Unlock()
return n, addr, err
}
if n < int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("aead").Base(io.ErrShortBuffer)
}
if len(p) < n-int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("aead").Base(io.ErrShortBuffer)
}
ciphertext := c.readBuf[:n]
opened, err := c.aead.Open(nil, nil, ciphertext, nil)
if err != nil {
c.readMutex.Unlock()
return 0, addr, errors.New("aead open").Base(err)
}
copy(p, opened)
c.readMutex.Unlock()
return n - int(c.Size()), addr, nil
}
n, addr, err = c.conn.ReadFrom(p)
if err != nil {
return n, addr, err
}
if n < int(c.Size()) {
return 0, addr, errors.New("aead").Base(io.ErrShortBuffer)
}
ciphertext := p[:n]
opened, err := c.aead.Open(nil, nil, ciphertext, nil)
if err != nil {
c.readMutex.Unlock()
return 0, addr, errors.New("aead open").Base(err)
}
copy(p, opened)
return n - int(c.Size()), addr, nil
}
func (c *simpleConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if c.first {
if c.leaveSize+c.Size()+int32(len(p)) > 8192 {
return 0, errors.New("too many masks")
}
c.writeMutex.Lock()
n = copy(c.writeBuf[c.leaveSize+c.Size():], p)
n += int(c.leaveSize) + int(c.Size())
plaintext := c.writeBuf[c.leaveSize+c.Size() : n]
sealed := c.aead.Seal(nil, nil, plaintext, nil)
copy(c.writeBuf[c.leaveSize:], sealed)
nn, err := c.conn.WriteTo(c.writeBuf[:n], addr)
if err != nil {
c.writeMutex.Unlock()
return 0, err
}
if nn != n {
c.writeMutex.Unlock()
return 0, errors.New("nn != n")
}
c.writeMutex.Unlock()
return len(p), nil
}
plaintext := p[c.leaveSize+c.Size():]
sealed := c.aead.Seal(nil, nil, plaintext, nil)
copy(p[c.leaveSize:], sealed)
return c.conn.WriteTo(p, addr)
}
func (c *simpleConn) Close() error {
return c.conn.Close()
}
func (c *simpleConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *simpleConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *simpleConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *simpleConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

View File

@@ -0,0 +1,19 @@
package original_test
import (
"crypto/rand"
"testing"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/original"
"github.com/stretchr/testify/assert"
)
func TestOriginalBounce(t *testing.T) {
aead := original.NewSimple()
buf := make([]byte, aead.NonceSize()+aead.Overhead())
for i := 0; i < 1000; i++ {
_, _ = rand.Read(buf)
_, err := aead.Open(buf[:0], nil, buf, nil)
assert.NotEqual(t, err, nil)
}
}

View File

@@ -1,7 +1,7 @@
//go:build !amd64
// +build !amd64
package kcp
package original
// xorfwd performs XOR forwards in words, x[i] ^= x[i-4], i from 0 to len
func xorfwd(x []byte) {

View File

@@ -1,4 +1,4 @@
package kcp
package original
//go:noescape
func xorfwd(x []byte)

View File

@@ -2,41 +2,15 @@ package salamander
import (
"net"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/salamander/obfs"
)
func (c *Config) UDP() {
}
func (c *Config) WrapConnClient(raw net.Conn) (net.Conn, error) {
return raw, nil
func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *Config) WrapConnServer(raw net.Conn) (net.Conn, error) {
return raw, nil
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn) (net.PacketConn, error) {
ob, err := obfs.NewSalamanderObfuscator([]byte(c.Password))
if err != nil {
return nil, errors.New("salamander err").Base(err)
}
return obfs.WrapPacketConn(raw, ob), nil
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn) (net.PacketConn, error) {
ob, err := obfs.NewSalamanderObfuscator([]byte(c.Password))
if err != nil {
return nil, errors.New("salamander err").Base(err)
}
return obfs.WrapPacketConn(raw, ob), nil
}
func (c *Config) Size() int {
return 0
}
func (c *Config) Serialize([]byte) {
func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnServer(c, raw, first, leaveSize)
}

View File

@@ -69,10 +69,10 @@ var File_transport_internet_finalmask_salamander_config_proto protoreflect.FileD
const file_transport_internet_finalmask_salamander_config_proto_rawDesc = "" +
"\n" +
"4transport/internet/finalmask/salamander/config.proto\x12*xray.transport.internet.udpmask.salamander\"$\n" +
"4transport/internet/finalmask/salamander/config.proto\x12,xray.transport.internet.finalmask.salamander\"$\n" +
"\x06Config\x12\x1a\n" +
"\bpassword\x18\x01 \x01(\tR\bpasswordB\xaf\x01\n" +
".com.xray.transport.internet.udpmask.salamanderP\x01ZNgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/udpmask/salamander\xaa\x02*Xray.Transport.Internet.Udpmask.Salamanderb\x06proto3"
"\bpassword\x18\x01 \x01(\tR\bpasswordB\xb5\x01\n" +
"0com.xray.transport.internet.finalmask.salamanderP\x01ZPgithub.com/amneiza-vpn/amnezia-xray-core/transport/internet/finalmask/salamander\xaa\x02,Xray.Transport.Internet.Finalmask.Salamanderb\x06proto3"
var (
file_transport_internet_finalmask_salamander_config_proto_rawDescOnce sync.Once
@@ -88,7 +88,7 @@ func file_transport_internet_finalmask_salamander_config_proto_rawDescGZIP() []b
var file_transport_internet_finalmask_salamander_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_salamander_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.udpmask.salamander.Config
(*Config)(nil), // 0: xray.transport.internet.finalmask.salamander.Config
}
var file_transport_internet_finalmask_salamander_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type

View File

@@ -1,9 +1,9 @@
syntax = "proto3";
package xray.transport.internet.udpmask.salamander;
option csharp_namespace = "Xray.Transport.Internet.Udpmask.Salamander";
option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/udpmask/salamander";
option java_package = "com.xray.transport.internet.udpmask.salamander";
package xray.transport.internet.finalmask.salamander;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Salamander";
option go_package = "github.com/amneiza-vpn/amnezia-xray-core/transport/internet/finalmask/salamander";
option java_package = "com.xray.transport.internet.finalmask.salamander";
option java_multiple_files = true;
message Config {

View File

@@ -0,0 +1,147 @@
package salamander
import (
"io"
"net"
"sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
)
type obfsPacketConn struct {
first bool
leaveSize int32
conn net.PacketConn
obfs *SalamanderObfuscator
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
ob, err := NewSalamanderObfuscator([]byte(c.Password))
if err != nil {
return nil, errors.New("salamander err").Base(err)
}
conn := &obfsPacketConn{
first: first,
leaveSize: leaveSize,
conn: raw,
obfs: ob,
}
if first {
conn.readBuf = make([]byte, 8192)
conn.writeBuf = make([]byte, 8192)
}
return conn, nil
}
func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) {
return NewConnClient(c, raw, first, leaveSize)
}
func (c *obfsPacketConn) Size() int32 {
return smSaltLen
}
func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
if c.first {
c.readMutex.Lock()
n, addr, err = c.conn.ReadFrom(c.readBuf)
if err != nil {
c.readMutex.Unlock()
return n, addr, err
}
if n < int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("salamander").Base(io.ErrShortBuffer)
}
if len(p) < n-int(c.Size()) {
c.readMutex.Unlock()
return 0, addr, errors.New("salamander").Base(io.ErrShortBuffer)
}
c.obfs.Deobfuscate(c.readBuf[:n], p)
c.readMutex.Unlock()
return n - int(c.Size()), addr, err
}
n, addr, err = c.conn.ReadFrom(p)
if err != nil {
return n, addr, err
}
if n < int(c.Size()) {
return 0, addr, errors.New("salamander").Base(io.ErrShortBuffer)
}
c.obfs.Deobfuscate(p[:n], p)
return n - int(c.Size()), addr, err
}
func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
if c.first {
if c.leaveSize+c.Size()+int32(len(p)) > 8192 {
return 0, errors.New("too many masks")
}
c.writeMutex.Lock()
n = copy(c.writeBuf[c.leaveSize+c.Size():], p)
n += int(c.leaveSize) + int(c.Size())
c.obfs.Obfuscate(c.writeBuf[c.leaveSize+c.Size():n], c.writeBuf[c.leaveSize:n])
nn, err := c.conn.WriteTo(c.writeBuf[:n], addr)
if err != nil {
c.writeMutex.Unlock()
return 0, err
}
if nn != n {
c.writeMutex.Unlock()
return 0, errors.New("nn != n")
}
c.writeMutex.Unlock()
return len(p), nil
}
c.obfs.Obfuscate(p[c.leaveSize+c.Size():], p[c.leaveSize:])
return c.conn.WriteTo(p, addr)
}
func (c *obfsPacketConn) Close() error {
return c.conn.Close()
}
func (c *obfsPacketConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *obfsPacketConn) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *obfsPacketConn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}

View File

@@ -1,121 +0,0 @@
package obfs
import (
"net"
"sync"
"syscall"
"time"
)
const udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough
// Obfuscator is the interface that wraps the Obfuscate and Deobfuscate methods.
// Both methods return the number of bytes written to out.
// If a packet is not valid, the methods should return 0.
type Obfuscator interface {
Obfuscate(in, out []byte) int
Deobfuscate(in, out []byte) int
}
var _ net.PacketConn = (*obfsPacketConn)(nil)
type obfsPacketConn struct {
Conn net.PacketConn
Obfs Obfuscator
readBuf []byte
readMutex sync.Mutex
writeBuf []byte
writeMutex sync.Mutex
}
// obfsPacketConnUDP is a special case of obfsPacketConn that uses a UDPConn
// as the underlying connection. We pass additional methods to quic-go to
// enable UDP-specific optimizations.
type obfsPacketConnUDP struct {
*obfsPacketConn
UDPConn *net.UDPConn
}
// WrapPacketConn enables obfuscation on a net.PacketConn.
// The obfuscation is transparent to the caller - the n bytes returned by
// ReadFrom and WriteTo are the number of original bytes, not after
// obfuscation/deobfuscation.
func WrapPacketConn(conn net.PacketConn, obfs Obfuscator) net.PacketConn {
opc := &obfsPacketConn{
Conn: conn,
Obfs: obfs,
readBuf: make([]byte, udpBufferSize),
writeBuf: make([]byte, udpBufferSize),
}
if udpConn, ok := conn.(*net.UDPConn); ok {
return &obfsPacketConnUDP{
obfsPacketConn: opc,
UDPConn: udpConn,
}
} else {
return opc
}
}
func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
for {
c.readMutex.Lock()
n, addr, err = c.Conn.ReadFrom(c.readBuf)
if n <= 0 {
c.readMutex.Unlock()
return n, addr, err
}
n = c.Obfs.Deobfuscate(c.readBuf[:n], p)
c.readMutex.Unlock()
if n > 0 || err != nil {
return n, addr, err
}
// Invalid packet, try again
}
}
func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
c.writeMutex.Lock()
nn := c.Obfs.Obfuscate(p, c.writeBuf)
_, err = c.Conn.WriteTo(c.writeBuf[:nn], addr)
c.writeMutex.Unlock()
if err == nil {
n = len(p)
}
return n, err
}
func (c *obfsPacketConn) Close() error {
return c.Conn.Close()
}
func (c *obfsPacketConn) LocalAddr() net.Addr {
return c.Conn.LocalAddr()
}
func (c *obfsPacketConn) SetDeadline(t time.Time) error {
return c.Conn.SetDeadline(t)
}
func (c *obfsPacketConn) SetReadDeadline(t time.Time) error {
return c.Conn.SetReadDeadline(t)
}
func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error {
return c.Conn.SetWriteDeadline(t)
}
// UDP-specific methods below
func (c *obfsPacketConnUDP) SetReadBuffer(bytes int) error {
return c.UDPConn.SetReadBuffer(bytes)
}
func (c *obfsPacketConnUDP) SetWriteBuffer(bytes int) error {
return c.UDPConn.SetWriteBuffer(bytes)
}
func (c *obfsPacketConnUDP) SyscallConn() (syscall.RawConn, error) {
return c.UDPConn.SyscallConn()
}

View File

@@ -1,45 +0,0 @@
package obfs
import (
"crypto/rand"
"testing"
"github.com/stretchr/testify/assert"
)
func BenchmarkSalamanderObfuscator_Obfuscate(b *testing.B) {
o, _ := NewSalamanderObfuscator([]byte("average_password"))
in := make([]byte, 1200)
_, _ = rand.Read(in)
out := make([]byte, 2048)
b.ResetTimer()
for i := 0; i < b.N; i++ {
o.Obfuscate(in, out)
}
}
func BenchmarkSalamanderObfuscator_Deobfuscate(b *testing.B) {
o, _ := NewSalamanderObfuscator([]byte("average_password"))
in := make([]byte, 1200)
_, _ = rand.Read(in)
out := make([]byte, 2048)
b.ResetTimer()
for i := 0; i < b.N; i++ {
o.Deobfuscate(in, out)
}
}
func TestSalamanderObfuscator(t *testing.T) {
o, _ := NewSalamanderObfuscator([]byte("average_password"))
in := make([]byte, 1200)
oOut := make([]byte, 2048)
dOut := make([]byte, 2048)
for i := 0; i < 1000; i++ {
_, _ = rand.Read(in)
n := o.Obfuscate(in, oOut)
assert.Equal(t, len(in)+smSaltLen, n)
n = o.Deobfuscate(oOut[:n], dOut)
assert.Equal(t, len(in), n)
assert.Equal(t, in, dOut[:n])
}
}

View File

@@ -1,4 +1,4 @@
package obfs
package salamander
import (
"fmt"
@@ -15,8 +15,6 @@ const (
smKeyLen = blake2b.Size256
)
var _ Obfuscator = (*SalamanderObfuscator)(nil)
var ErrPSKTooShort = fmt.Errorf("PSK must be at least %d bytes", smPSKMinLen)
// SalamanderObfuscator is an obfuscator that obfuscates each packet with

View File

@@ -0,0 +1,81 @@
package salamander_test
import (
"crypto/rand"
"testing"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/salamander"
"github.com/stretchr/testify/assert"
)
const (
smSaltLen = 8
)
func BenchmarkSalamanderObfuscator_Obfuscate(b *testing.B) {
o, _ := salamander.NewSalamanderObfuscator([]byte("average_password"))
in := make([]byte, 1200)
_, _ = rand.Read(in)
out := make([]byte, 2048)
b.ResetTimer()
for i := 0; i < b.N; i++ {
o.Obfuscate(in, out)
}
}
func BenchmarkSalamanderObfuscator_Deobfuscate(b *testing.B) {
o, _ := salamander.NewSalamanderObfuscator([]byte("average_password"))
in := make([]byte, 1200)
_, _ = rand.Read(in)
out := make([]byte, 2048)
b.ResetTimer()
for i := 0; i < b.N; i++ {
o.Deobfuscate(in, out)
}
}
func TestSalamanderObfuscator(t *testing.T) {
o, _ := salamander.NewSalamanderObfuscator([]byte("average_password"))
in := make([]byte, 1200)
oOut := make([]byte, 2048)
dOut := make([]byte, 2048)
for i := 0; i < 1000; i++ {
_, _ = rand.Read(in)
n := o.Obfuscate(in, oOut)
assert.Equal(t, len(in)+smSaltLen, n)
n = o.Deobfuscate(oOut[:n], dOut)
assert.Equal(t, len(in), n)
assert.Equal(t, in, dOut[:n])
}
}
func TestSalamanderInPlace(t *testing.T) {
o, _ := salamander.NewSalamanderObfuscator([]byte("average_password"))
in := make([]byte, 1200)
out := make([]byte, 2048)
_, _ = rand.Read(in)
o.Obfuscate(in, out)
out2 := make([]byte, 2048)
copy(out2[smSaltLen:], in)
o.Obfuscate(out2[smSaltLen:], out2)
dOut := make([]byte, 2048)
o.Deobfuscate(out, dOut)
o.Deobfuscate(out2, out2)
assert.Equal(t, in, dOut[:1200])
assert.Equal(t, in, out2[:1200])
}
func TestSalamanderBounce(t *testing.T) {
o, _ := salamander.NewSalamanderObfuscator([]byte("average_password"))
buf := make([]byte, 8)
for i := 0; i < 1000; i++ {
_, _ = rand.Read(buf)
n := o.Deobfuscate(buf, buf)
assert.Equal(t, 0, n)
}
}

View File

@@ -0,0 +1,129 @@
package finalmask_test
import (
"bytes"
"net"
"testing"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dns"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/srtp"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/utp"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wechat"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wireguard"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/aes128gcm"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/original"
"github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/salamander"
)
func mustSendRecv(
t *testing.T,
from net.PacketConn,
to net.PacketConn,
msg []byte,
) {
t.Helper()
go func() {
_, err := from.WriteTo(msg, to.LocalAddr())
if err != nil {
t.Error(err)
}
}()
buf := make([]byte, 1024)
n, _, err := to.ReadFrom(buf)
if err != nil {
t.Fatal(err)
}
if n != len(msg) {
t.Fatalf("unexpected size: %d", n)
}
if !bytes.Equal(buf[:n], msg) {
t.Fatalf("unexpected data")
}
}
type layerMask struct {
name string
mask finalmask.Udpmask
}
func TestPacketConnReadWrite(t *testing.T) {
cases := []layerMask{
{
name: "aes128gcm",
mask: &aes128gcm.Config{Password: "123"},
},
{
name: "original",
mask: &original.Config{},
},
{
name: "dns",
mask: &dns.Config{Domain: "www.baidu.com"},
},
{
name: "srtp",
mask: &srtp.Config{},
},
{
name: "utp",
mask: &utp.Config{},
},
{
name: "wechat",
mask: &wechat.Config{},
},
{
name: "wireguard",
mask: &wireguard.Config{},
},
{
name: "salamander",
mask: &salamander.Config{Password: "1234"},
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
mask := c.mask
maskManager := finalmask.NewUdpmaskManager([]finalmask.Udpmask{mask, mask})
client, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer client.Close()
client, err = maskManager.WrapPacketConnClient(client)
if err != nil {
t.Fatal(err)
}
server, err := net.ListenPacket("udp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer server.Close()
server, err = maskManager.WrapPacketConnServer(server)
if err != nil {
t.Fatal(err)
}
_ = client.SetDeadline(time.Now().Add(time.Second))
_ = server.SetDeadline(time.Now().Add(time.Second))
mustSendRecv(t, client, server, []byte("client -> server"))
mustSendRecv(t, server, client, []byte("server -> client"))
mustSendRecv(t, client, server, []byte{})
mustSendRecv(t, server, client, []byte{})
})
}
}

View File

@@ -0,0 +1,373 @@
package xdns
import (
"bytes"
"context"
"crypto/rand"
"encoding/base32"
"encoding/binary"
"io"
"net"
"sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
)
const (
numPadding = 3
numPaddingForPoll = 8
initPollDelay = 500 * time.Millisecond
maxPollDelay = 10 * time.Second
pollDelayMultiplier = 2.0
pollLimit = 16
)
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
type packet struct {
p []byte
addr net.Addr
}
type xdnsConnClient struct {
conn net.PacketConn
clientID []byte
domain Name
pollChan chan struct{}
readQueue chan *packet
writeQueue chan *packet
closed bool
mutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) {
if !end {
return nil, errors.New("xdns requires being at the outermost level")
}
domain, err := ParseName(c.Domain)
if err != nil {
return nil, err
}
conn := &xdnsConnClient{
conn: raw,
clientID: make([]byte, 8),
domain: domain,
pollChan: make(chan struct{}, pollLimit),
readQueue: make(chan *packet, 128),
writeQueue: make(chan *packet, 128),
}
rand.Read(conn.clientID)
go conn.recvLoop()
go conn.sendLoop()
return conn, nil
}
func (c *xdnsConnClient) recvLoop() {
for {
if c.closed {
break
}
var buf [4096]byte
n, addr, err := c.conn.ReadFrom(buf[:])
if err != nil {
continue
}
resp, err := MessageFromWireFormat(buf[:n])
if err != nil {
continue
}
payload := dnsResponsePayload(&resp, c.domain)
r := bytes.NewReader(payload)
anyPacket := false
for {
p, err := nextPacket(r)
if err != nil {
break
}
anyPacket = true
buf := make([]byte, len(p))
copy(buf, p)
select {
case c.readQueue <- &packet{
p: buf,
addr: addr,
}:
default:
}
}
if anyPacket {
select {
case c.pollChan <- struct{}{}:
default:
}
}
}
close(c.pollChan)
close(c.readQueue)
}
func (c *xdnsConnClient) sendLoop() {
var addr net.Addr
pollDelay := initPollDelay
pollTimer := time.NewTimer(pollDelay)
for {
var p *packet
pollTimerExpired := false
select {
case p = <-c.writeQueue:
default:
select {
case p = <-c.writeQueue:
case <-c.pollChan:
case <-pollTimer.C:
pollTimerExpired = true
}
}
if p != nil {
addr = p.addr
select {
case <-c.pollChan:
default:
}
} else if addr != nil {
encoded, _ := encode(nil, c.clientID, c.domain)
p = &packet{
p: encoded,
addr: addr,
}
}
if pollTimerExpired {
pollDelay = time.Duration(float64(pollDelay) * pollDelayMultiplier)
if pollDelay > maxPollDelay {
pollDelay = maxPollDelay
}
} else {
if !pollTimer.Stop() {
<-pollTimer.C
}
pollDelay = initPollDelay
}
pollTimer.Reset(pollDelay)
if c.closed {
return
}
if p != nil {
_, _ = c.conn.WriteTo(p.p, p.addr)
}
}
}
func (c *xdnsConnClient) Size() int32 {
return 0
}
func (c *xdnsConnClient) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
packet, ok := <-c.readQueue
if !ok {
return 0, nil, io.EOF
}
n = copy(p, packet.p)
if n != len(packet.p) {
return 0, nil, io.ErrShortBuffer
}
return n, packet.addr, nil
}
func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return 0, errors.New("xdns closed")
}
encoded, err := encode(p, c.clientID, c.domain)
if err != nil {
errors.LogDebug(context.Background(), "xdns encode err ", err)
return 0, errors.New("xdns encode").Base(err)
}
select {
case c.writeQueue <- &packet{
p: encoded,
addr: addr,
}:
return len(p), nil
default:
return 0, errors.New("xdns queue full")
}
}
func (c *xdnsConnClient) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return nil
}
c.closed = true
close(c.writeQueue)
return c.conn.Close()
}
func (c *xdnsConnClient) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *xdnsConnClient) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *xdnsConnClient) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *xdnsConnClient) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func encode(p []byte, clientID []byte, domain Name) ([]byte, error) {
var decoded []byte
{
if len(p) >= 224 {
return nil, errors.New("too long")
}
var buf bytes.Buffer
buf.Write(clientID[:])
n := numPadding
if len(p) == 0 {
n = numPaddingForPoll
}
buf.WriteByte(byte(224 + n))
_, _ = io.CopyN(&buf, rand.Reader, int64(n))
if len(p) > 0 {
buf.WriteByte(byte(len(p)))
buf.Write(p)
}
decoded = buf.Bytes()
}
encoded := make([]byte, base32Encoding.EncodedLen(len(decoded)))
base32Encoding.Encode(encoded, decoded)
encoded = bytes.ToLower(encoded)
labels := chunks(encoded, 63)
labels = append(labels, domain...)
name, err := NewName(labels)
if err != nil {
return nil, err
}
var id uint16
_ = binary.Read(rand.Reader, binary.BigEndian, &id)
query := &Message{
ID: id,
Flags: 0x0100,
Question: []Question{
{
Name: name,
Type: RRTypeTXT,
Class: ClassIN,
},
},
Additional: []RR{
{
Name: Name{},
Type: RRTypeOPT,
Class: 4096,
TTL: 0,
Data: []byte{},
},
},
}
buf, err := query.WireFormat()
if err != nil {
return nil, err
}
return buf, nil
}
func chunks(p []byte, n int) [][]byte {
var result [][]byte
for len(p) > 0 {
sz := len(p)
if sz > n {
sz = n
}
result = append(result, p[:sz])
p = p[sz:]
}
return result
}
func nextPacket(r *bytes.Reader) ([]byte, error) {
var n uint16
err := binary.Read(r, binary.BigEndian, &n)
if err != nil {
return nil, err
}
p := make([]byte, n)
_, err = io.ReadFull(r, p)
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return p, err
}
func dnsResponsePayload(resp *Message, domain Name) []byte {
if resp.Flags&0x8000 != 0x8000 {
return nil
}
if resp.Flags&0x000f != RcodeNoError {
return nil
}
if len(resp.Answer) != 1 {
return nil
}
answer := resp.Answer[0]
_, ok := answer.Name.TrimSuffix(domain)
if !ok {
return nil
}
if answer.Type != RRTypeTXT {
return nil
}
payload, err := DecodeRDataTXT(answer.Data)
if err != nil {
return nil
}
return payload
}

View File

@@ -0,0 +1,16 @@
package xdns
import (
"net"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnClient(c, raw, end)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnServer(c, raw, end)
}

View File

@@ -0,0 +1,123 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v3.21.12
// source: transport/internet/finalmask/xdns/config.proto
package xdns
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Config) Reset() {
*x = Config{}
mi := &file_transport_internet_finalmask_xdns_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Config) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Config) ProtoMessage() {}
func (x *Config) ProtoReflect() protoreflect.Message {
mi := &file_transport_internet_finalmask_xdns_config_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_xdns_config_proto_rawDescGZIP(), []int{0}
}
func (x *Config) GetDomain() string {
if x != nil {
return x.Domain
}
return ""
}
var File_transport_internet_finalmask_xdns_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_xdns_config_proto_rawDesc = "" +
"\n" +
".transport/internet/finalmask/xdns/config.proto\x12&xray.transport.internet.finalmask.xdns\" \n" +
"\x06Config\x12\x16\n" +
"\x06domain\x18\x01 \x01(\tR\x06domainB\xa3\x01\n" +
"*com.xray.transport.internet.finalmask.xdnsP\x01ZJgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xdns\xaa\x02&Xray.Transport.Internet.Finalmask.Xdnsb\x06proto3"
var (
file_transport_internet_finalmask_xdns_config_proto_rawDescOnce sync.Once
file_transport_internet_finalmask_xdns_config_proto_rawDescData []byte
)
func file_transport_internet_finalmask_xdns_config_proto_rawDescGZIP() []byte {
file_transport_internet_finalmask_xdns_config_proto_rawDescOnce.Do(func() {
file_transport_internet_finalmask_xdns_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_xdns_config_proto_rawDesc), len(file_transport_internet_finalmask_xdns_config_proto_rawDesc)))
})
return file_transport_internet_finalmask_xdns_config_proto_rawDescData
}
var file_transport_internet_finalmask_xdns_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_xdns_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.finalmask.xdns.Config
}
var file_transport_internet_finalmask_xdns_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_transport_internet_finalmask_xdns_config_proto_init() }
func file_transport_internet_finalmask_xdns_config_proto_init() {
if File_transport_internet_finalmask_xdns_config_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_xdns_config_proto_rawDesc), len(file_transport_internet_finalmask_xdns_config_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_transport_internet_finalmask_xdns_config_proto_goTypes,
DependencyIndexes: file_transport_internet_finalmask_xdns_config_proto_depIdxs,
MessageInfos: file_transport_internet_finalmask_xdns_config_proto_msgTypes,
}.Build()
File_transport_internet_finalmask_xdns_config_proto = out.File
file_transport_internet_finalmask_xdns_config_proto_goTypes = nil
file_transport_internet_finalmask_xdns_config_proto_depIdxs = nil
}

View File

@@ -0,0 +1,12 @@
syntax = "proto3";
package xray.transport.internet.finalmask.xdns;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Xdns";
option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xdns";
option java_package = "com.xray.transport.internet.finalmask.xdns";
option java_multiple_files = true;
message Config {
string domain = 1;
}

View File

@@ -0,0 +1,575 @@
// Package dns deals with encoding and decoding DNS wire format.
package xdns
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
)
// The maximum number of DNS name compression pointers we are willing to follow.
// Without something like this, infinite loops are possible.
const compressionPointerLimit = 10
var (
// ErrZeroLengthLabel is the error returned for names that contain a
// zero-length label, like "example..com".
ErrZeroLengthLabel = errors.New("name contains a zero-length label")
// ErrLabelTooLong is the error returned for labels that are longer than
// 63 octets.
ErrLabelTooLong = errors.New("name contains a label longer than 63 octets")
// ErrNameTooLong is the error returned for names whose encoded
// representation is longer than 255 octets.
ErrNameTooLong = errors.New("name is longer than 255 octets")
// ErrReservedLabelType is the error returned when reading a label type
// prefix whose two most significant bits are not 00 or 11.
ErrReservedLabelType = errors.New("reserved label type")
// ErrTooManyPointers is the error returned when reading a compressed
// name that has too many compression pointers.
ErrTooManyPointers = errors.New("too many compression pointers")
// ErrTrailingBytes is the error returned when bytes remain in the parse
// buffer after parsing a message.
ErrTrailingBytes = errors.New("trailing bytes after message")
// ErrIntegerOverflow is the error returned when trying to encode an
// integer greater than 65535 into a 16-bit field.
ErrIntegerOverflow = errors.New("integer overflow")
)
const (
// https://tools.ietf.org/html/rfc1035#section-3.2.2
RRTypeTXT = 16
// https://tools.ietf.org/html/rfc6891#section-6.1.1
RRTypeOPT = 41
// https://tools.ietf.org/html/rfc1035#section-3.2.4
ClassIN = 1
// https://tools.ietf.org/html/rfc1035#section-4.1.1
RcodeNoError = 0 // a.k.a. NOERROR
RcodeFormatError = 1 // a.k.a. FORMERR
RcodeNameError = 3 // a.k.a. NXDOMAIN
RcodeNotImplemented = 4 // a.k.a. NOTIMPL
// https://tools.ietf.org/html/rfc6891#section-9
ExtendedRcodeBadVers = 16 // a.k.a. BADVERS
)
// Name represents a domain name, a sequence of labels each of which is 63
// octets or less in length.
//
// https://tools.ietf.org/html/rfc1035#section-3.1
type Name [][]byte
// NewName returns a Name from a slice of labels, after checking the labels for
// validity. Does not include a zero-length label at the end of the slice.
func NewName(labels [][]byte) (Name, error) {
name := Name(labels)
// https://tools.ietf.org/html/rfc1035#section-2.3.4
// Various objects and parameters in the DNS have size limits.
// labels 63 octets or less
// names 255 octets or less
for _, label := range labels {
if len(label) == 0 {
return nil, ErrZeroLengthLabel
}
if len(label) > 63 {
return nil, ErrLabelTooLong
}
}
// Check the total length.
builder := newMessageBuilder()
builder.WriteName(name)
if len(builder.Bytes()) > 255 {
return nil, ErrNameTooLong
}
return name, nil
}
// ParseName returns a new Name from a string of labels separated by dots, after
// checking the name for validity. A single dot at the end of the string is
// ignored.
func ParseName(s string) (Name, error) {
b := bytes.TrimSuffix([]byte(s), []byte("."))
if len(b) == 0 {
// bytes.Split(b, ".") would return [""] in this case
return NewName([][]byte{})
} else {
return NewName(bytes.Split(b, []byte(".")))
}
}
// String returns a reversible string representation of name. Labels are
// separated by dots, and any bytes in a label that are outside the set
// [0-9A-Za-z-] are replaced with a \xXX hex escape sequence.
func (name Name) String() string {
if len(name) == 0 {
return "."
}
var buf strings.Builder
for i, label := range name {
if i > 0 {
buf.WriteByte('.')
}
for _, b := range label {
if b == '-' ||
('0' <= b && b <= '9') ||
('A' <= b && b <= 'Z') ||
('a' <= b && b <= 'z') {
buf.WriteByte(b)
} else {
fmt.Fprintf(&buf, "\\x%02x", b)
}
}
}
return buf.String()
}
// TrimSuffix returns a Name with the given suffix removed, if it was present.
// The second return value indicates whether the suffix was present. If the
// suffix was not present, the first return value is nil.
func (name Name) TrimSuffix(suffix Name) (Name, bool) {
if len(name) < len(suffix) {
return nil, false
}
split := len(name) - len(suffix)
fore, aft := name[:split], name[split:]
for i := 0; i < len(aft); i++ {
if !bytes.Equal(bytes.ToLower(aft[i]), bytes.ToLower(suffix[i])) {
return nil, false
}
}
return fore, true
}
// Message represents a DNS message.
//
// https://tools.ietf.org/html/rfc1035#section-4.1
type Message struct {
ID uint16
Flags uint16
Question []Question
Answer []RR
Authority []RR
Additional []RR
}
// Opcode extracts the OPCODE part of the Flags field.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.1
func (message *Message) Opcode() uint16 {
return (message.Flags >> 11) & 0xf
}
// Rcode extracts the RCODE part of the Flags field.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.1
func (message *Message) Rcode() uint16 {
return message.Flags & 0x000f
}
// Question represents an entry in the question section of a message.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.2
type Question struct {
Name Name
Type uint16
Class uint16
}
// RR represents a resource record.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.3
type RR struct {
Name Name
Type uint16
Class uint16
TTL uint32
Data []byte
}
// readName parses a DNS name from r. It leaves r positioned just after the
// parsed name.
func readName(r io.ReadSeeker) (Name, error) {
var labels [][]byte
// We limit the number of compression pointers we are willing to follow.
numPointers := 0
// If we followed any compression pointers, we must finally seek to just
// past the first pointer.
var seekTo int64
loop:
for {
var labelType byte
err := binary.Read(r, binary.BigEndian, &labelType)
if err != nil {
return nil, err
}
switch labelType & 0xc0 {
case 0x00:
// This is an ordinary label.
// https://tools.ietf.org/html/rfc1035#section-3.1
length := int(labelType & 0x3f)
if length == 0 {
break loop
}
label := make([]byte, length)
_, err := io.ReadFull(r, label)
if err != nil {
return nil, err
}
labels = append(labels, label)
case 0xc0:
// This is a compression pointer.
// https://tools.ietf.org/html/rfc1035#section-4.1.4
upper := labelType & 0x3f
var lower byte
err := binary.Read(r, binary.BigEndian, &lower)
if err != nil {
return nil, err
}
offset := (uint16(upper) << 8) | uint16(lower)
if numPointers == 0 {
// The first time we encounter a pointer,
// remember our position so we can seek back to
// it when done.
seekTo, err = r.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
}
}
numPointers++
if numPointers > compressionPointerLimit {
return nil, ErrTooManyPointers
}
// Follow the pointer and continue.
_, err = r.Seek(int64(offset), io.SeekStart)
if err != nil {
return nil, err
}
default:
// "The 10 and 01 combinations are reserved for future
// use."
return nil, ErrReservedLabelType
}
}
// If we followed any pointers, then seek back to just after the first
// one.
if numPointers > 0 {
_, err := r.Seek(seekTo, io.SeekStart)
if err != nil {
return nil, err
}
}
return NewName(labels)
}
// readQuestion parses one entry from the Question section. It leaves r
// positioned just after the parsed entry.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.2
func readQuestion(r io.ReadSeeker) (Question, error) {
var question Question
var err error
question.Name, err = readName(r)
if err != nil {
return question, err
}
for _, ptr := range []*uint16{&question.Type, &question.Class} {
err := binary.Read(r, binary.BigEndian, ptr)
if err != nil {
return question, err
}
}
return question, nil
}
// readRR parses one resource record. It leaves r positioned just after the
// parsed resource record.
//
// https://tools.ietf.org/html/rfc1035#section-4.1.3
func readRR(r io.ReadSeeker) (RR, error) {
var rr RR
var err error
rr.Name, err = readName(r)
if err != nil {
return rr, err
}
for _, ptr := range []*uint16{&rr.Type, &rr.Class} {
err := binary.Read(r, binary.BigEndian, ptr)
if err != nil {
return rr, err
}
}
err = binary.Read(r, binary.BigEndian, &rr.TTL)
if err != nil {
return rr, err
}
var rdLength uint16
err = binary.Read(r, binary.BigEndian, &rdLength)
if err != nil {
return rr, err
}
rr.Data = make([]byte, rdLength)
_, err = io.ReadFull(r, rr.Data)
if err != nil {
return rr, err
}
return rr, nil
}
// readMessage parses a complete DNS message. It leaves r positioned just after
// the parsed message.
func readMessage(r io.ReadSeeker) (Message, error) {
var message Message
// Header section
// https://tools.ietf.org/html/rfc1035#section-4.1.1
var qdCount, anCount, nsCount, arCount uint16
for _, ptr := range []*uint16{
&message.ID, &message.Flags,
&qdCount, &anCount, &nsCount, &arCount,
} {
err := binary.Read(r, binary.BigEndian, ptr)
if err != nil {
return message, err
}
}
// Question section
// https://tools.ietf.org/html/rfc1035#section-4.1.2
for i := 0; i < int(qdCount); i++ {
question, err := readQuestion(r)
if err != nil {
return message, err
}
message.Question = append(message.Question, question)
}
// Answer, Authority, and Additional sections
// https://tools.ietf.org/html/rfc1035#section-4.1.3
for _, rec := range []struct {
ptr *[]RR
count uint16
}{
{&message.Answer, anCount},
{&message.Authority, nsCount},
{&message.Additional, arCount},
} {
for i := 0; i < int(rec.count); i++ {
rr, err := readRR(r)
if err != nil {
return message, err
}
*rec.ptr = append(*rec.ptr, rr)
}
}
return message, nil
}
// MessageFromWireFormat parses a message from buf and returns a Message object.
// It returns ErrTrailingBytes if there are bytes remaining in buf after parsing
// is done.
func MessageFromWireFormat(buf []byte) (Message, error) {
r := bytes.NewReader(buf)
message, err := readMessage(r)
if err == io.EOF {
err = io.ErrUnexpectedEOF
} else if err == nil {
// Check for trailing bytes.
_, err = r.ReadByte()
if err == io.EOF {
err = nil
} else if err == nil {
err = ErrTrailingBytes
}
}
return message, err
}
// messageBuilder manages the state of serializing a DNS message. Its main
// function is to keep track of names already written for the purpose of name
// compression.
type messageBuilder struct {
w bytes.Buffer
nameCache map[string]int
}
// newMessageBuilder creates a new messageBuilder with an empty name cache.
func newMessageBuilder() *messageBuilder {
return &messageBuilder{
nameCache: make(map[string]int),
}
}
// Bytes returns the serialized DNS message as a slice of bytes.
func (builder *messageBuilder) Bytes() []byte {
return builder.w.Bytes()
}
// WriteName appends name to the in-progress messageBuilder, employing
// compression pointers to previously written names if possible.
func (builder *messageBuilder) WriteName(name Name) {
// https://tools.ietf.org/html/rfc1035#section-3.1
for i := range name {
// Has this suffix already been encoded in the message?
if ptr, ok := builder.nameCache[name[i:].String()]; ok && ptr&0x3fff == ptr {
// If so, we can write a compression pointer.
binary.Write(&builder.w, binary.BigEndian, uint16(0xc000|ptr))
return
}
// Not cached; we must encode this label verbatim. Store a cache
// entry pointing to the beginning of it.
builder.nameCache[name[i:].String()] = builder.w.Len()
length := len(name[i])
if length == 0 || length > 63 {
panic(length)
}
builder.w.WriteByte(byte(length))
builder.w.Write(name[i])
}
builder.w.WriteByte(0)
}
// WriteQuestion appends a Question section entry to the in-progress
// messageBuilder.
func (builder *messageBuilder) WriteQuestion(question *Question) {
// https://tools.ietf.org/html/rfc1035#section-4.1.2
builder.WriteName(question.Name)
binary.Write(&builder.w, binary.BigEndian, question.Type)
binary.Write(&builder.w, binary.BigEndian, question.Class)
}
// WriteRR appends a resource record to the in-progress messageBuilder. It
// returns ErrIntegerOverflow if the length of rr.Data does not fit in 16 bits.
func (builder *messageBuilder) WriteRR(rr *RR) error {
// https://tools.ietf.org/html/rfc1035#section-4.1.3
builder.WriteName(rr.Name)
binary.Write(&builder.w, binary.BigEndian, rr.Type)
binary.Write(&builder.w, binary.BigEndian, rr.Class)
binary.Write(&builder.w, binary.BigEndian, rr.TTL)
rdLength := uint16(len(rr.Data))
if int(rdLength) != len(rr.Data) {
return ErrIntegerOverflow
}
binary.Write(&builder.w, binary.BigEndian, rdLength)
builder.w.Write(rr.Data)
return nil
}
// WriteMessage appends a complete DNS message to the in-progress
// messageBuilder. It returns ErrIntegerOverflow if the number of entries in any
// section, or the length of the data in any resource record, does not fit in 16
// bits.
func (builder *messageBuilder) WriteMessage(message *Message) error {
// Header section
// https://tools.ietf.org/html/rfc1035#section-4.1.1
binary.Write(&builder.w, binary.BigEndian, message.ID)
binary.Write(&builder.w, binary.BigEndian, message.Flags)
for _, count := range []int{
len(message.Question),
len(message.Answer),
len(message.Authority),
len(message.Additional),
} {
count16 := uint16(count)
if int(count16) != count {
return ErrIntegerOverflow
}
binary.Write(&builder.w, binary.BigEndian, count16)
}
// Question section
// https://tools.ietf.org/html/rfc1035#section-4.1.2
for _, question := range message.Question {
builder.WriteQuestion(&question)
}
// Answer, Authority, and Additional sections
// https://tools.ietf.org/html/rfc1035#section-4.1.3
for _, rrs := range [][]RR{message.Answer, message.Authority, message.Additional} {
for _, rr := range rrs {
err := builder.WriteRR(&rr)
if err != nil {
return err
}
}
}
return nil
}
// WireFormat encodes a Message as a slice of bytes in DNS wire format. It
// returns ErrIntegerOverflow if the number of entries in any section, or the
// length of the data in any resource record, does not fit in 16 bits.
func (message *Message) WireFormat() ([]byte, error) {
builder := newMessageBuilder()
err := builder.WriteMessage(message)
if err != nil {
return nil, err
}
return builder.Bytes(), nil
}
// DecodeRDataTXT decodes TXT-DATA (as found in the RDATA for a resource record
// with TYPE=TXT) as a raw byte slice, by concatenating all the
// <character-string>s it contains.
//
// https://tools.ietf.org/html/rfc1035#section-3.3.14
func DecodeRDataTXT(p []byte) ([]byte, error) {
var buf bytes.Buffer
for {
if len(p) == 0 {
return nil, io.ErrUnexpectedEOF
}
n := int(p[0])
p = p[1:]
if len(p) < n {
return nil, io.ErrUnexpectedEOF
}
buf.Write(p[:n])
p = p[n:]
if len(p) == 0 {
break
}
}
return buf.Bytes(), nil
}
// EncodeRDataTXT encodes a slice of bytes as TXT-DATA, as appropriate for the
// RDATA of a resource record with TYPE=TXT. No length restriction is enforced
// here; that must be checked at a higher level.
//
// https://tools.ietf.org/html/rfc1035#section-3.3.14
func EncodeRDataTXT(p []byte) []byte {
// https://tools.ietf.org/html/rfc1035#section-3.3
// https://tools.ietf.org/html/rfc1035#section-3.3.14
// TXT data is a sequence of one or more <character-string>s, where
// <character-string> is a length octet followed by that number of
// octets.
var buf bytes.Buffer
for len(p) > 255 {
buf.WriteByte(255)
buf.Write(p[:255])
p = p[255:]
}
// Must write here, even if len(p) == 0, because it's "*one or more*
// <character-string>s".
buf.WriteByte(byte(len(p)))
buf.Write(p)
return buf.Bytes()
}

View File

@@ -0,0 +1,592 @@
package xdns
import (
"bytes"
"fmt"
"io"
"strconv"
"strings"
"testing"
)
func namesEqual(a, b Name) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if !bytes.Equal(a[i], b[i]) {
return false
}
}
return true
}
func TestName(t *testing.T) {
for _, test := range []struct {
labels [][]byte
err error
s string
}{
{[][]byte{}, nil, "."},
{[][]byte{[]byte("test")}, nil, "test"},
{[][]byte{[]byte("a"), []byte("b"), []byte("c")}, nil, "a.b.c"},
{[][]byte{{}}, ErrZeroLengthLabel, ""},
{[][]byte{[]byte("a"), {}, []byte("c")}, ErrZeroLengthLabel, ""},
// 63 octets.
{[][]byte{[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE")}, nil,
"0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"},
// 64 octets.
{[][]byte{[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDEF")}, ErrLabelTooLong, ""},
// 64+64+64+62 octets.
{[][]byte{
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABC"),
}, nil,
"0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABC"},
// 64+64+64+63 octets.
{[][]byte{
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"),
[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCD"),
}, ErrNameTooLong, ""},
// 127 one-octet labels.
{[][]byte{
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'},
}, nil,
"0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E"},
// 128 one-octet labels.
{[][]byte{
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'},
{'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'},
}, ErrNameTooLong, ""},
} {
// Test that NewName returns proper error codes, and otherwise
// returns an equal slice of labels.
name, err := NewName(test.labels)
if err != test.err || (err == nil && !namesEqual(name, test.labels)) {
t.Errorf("%+q returned (%+q, %v), expected (%+q, %v)",
test.labels, name, err, test.labels, test.err)
continue
}
if test.err != nil {
continue
}
// Test that the string version of the name comes out as
// expected.
s := name.String()
if s != test.s {
t.Errorf("%+q became string %+q, expected %+q", test.labels, s, test.s)
continue
}
// Test that parsing from a string back to a Name results in the
// original slice of labels.
name, err = ParseName(s)
if err != nil || !namesEqual(name, test.labels) {
t.Errorf("%+q parsing %+q returned (%+q, %v), expected (%+q, %v)",
test.labels, s, name, err, test.labels, nil)
continue
}
// A trailing dot should be ignored.
if !strings.HasSuffix(s, ".") {
dotName, dotErr := ParseName(s + ".")
if dotErr != err || !namesEqual(dotName, name) {
t.Errorf("%+q parsing %+q returned (%+q, %v), expected (%+q, %v)",
test.labels, s+".", dotName, dotErr, name, err)
continue
}
}
}
}
func TestParseName(t *testing.T) {
for _, test := range []struct {
s string
name Name
err error
}{
// This case can't be tested by TestName above because String
// will never produce "" (it produces "." instead).
{"", [][]byte{}, nil},
} {
name, err := ParseName(test.s)
if err != test.err || (err == nil && !namesEqual(name, test.name)) {
t.Errorf("%+q returned (%+q, %v), expected (%+q, %v)",
test.s, name, err, test.name, test.err)
continue
}
}
}
func unescapeString(s string) ([][]byte, error) {
if s == "." {
return [][]byte{}, nil
}
var result [][]byte
for _, label := range strings.Split(s, ".") {
var buf bytes.Buffer
i := 0
for i < len(label) {
switch label[i] {
case '\\':
if i+3 >= len(label) {
return nil, fmt.Errorf("truncated escape sequence at index %v", i)
}
if label[i+1] != 'x' {
return nil, fmt.Errorf("malformed escape sequence at index %v", i)
}
b, err := strconv.ParseUint(string(label[i+2:i+4]), 16, 8)
if err != nil {
return nil, fmt.Errorf("malformed hex sequence at index %v", i+2)
}
buf.WriteByte(byte(b))
i += 4
default:
buf.WriteByte(label[i])
i++
}
}
result = append(result, buf.Bytes())
}
return result, nil
}
func TestNameString(t *testing.T) {
for _, test := range []struct {
name Name
s string
}{
{[][]byte{}, "."},
{[][]byte{[]byte("\x00"), []byte("a.b"), []byte("c\nd\\")}, "\\x00.a\\x2eb.c\\x0ad\\x5c"},
{[][]byte{
[]byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>"),
[]byte("?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}"),
[]byte("~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc"),
[]byte("\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb"),
[]byte("\xfc\xfd\xfe\xff"),
}, "\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f\\x20\\x21\\x22\\x23\\x24\\x25\\x26\\x27\\x28\\x29\\x2a\\x2b\\x2c-\\x2e\\x2f0123456789\\x3a\\x3b\\x3c\\x3d\\x3e.\\x3f\\x40ABCDEFGHIJKLMNOPQRSTUVWXYZ\\x5b\\x5c\\x5d\\x5e\\x5f\\x60abcdefghijklmnopqrstuvwxyz\\x7b\\x7c\\x7d.\\x7e\\x7f\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9\\xaa\\xab\\xac\\xad\\xae\\xaf\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7\\xb8\\xb9\\xba\\xbb\\xbc.\\xbd\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb.\\xfc\\xfd\\xfe\\xff"},
} {
s := test.name.String()
if s != test.s {
t.Errorf("%+q escaped to %+q, expected %+q", test.name, s, test.s)
continue
}
unescaped, err := unescapeString(s)
if err != nil {
t.Errorf("%+q unescaping %+q resulted in error %v", test.name, s, err)
continue
}
if !namesEqual(Name(unescaped), test.name) {
t.Errorf("%+q roundtripped through %+q to %+q", test.name, s, unescaped)
continue
}
}
}
func TestNameTrimSuffix(t *testing.T) {
for _, test := range []struct {
name, suffix string
trimmed string
ok bool
}{
{"", "", ".", true},
{".", ".", ".", true},
{"abc", "", "abc", true},
{"abc", ".", "abc", true},
{"", "abc", ".", false},
{".", "abc", ".", false},
{"example.com", "com", "example", true},
{"example.com", "net", ".", false},
{"example.com", "example.com", ".", true},
{"example.com", "test.com", ".", false},
{"example.com", "xample.com", ".", false},
{"example.com", "example", ".", false},
{"example.com", "COM", "example", true},
{"EXAMPLE.COM", "com", "EXAMPLE", true},
} {
tmp, ok := mustParseName(test.name).TrimSuffix(mustParseName(test.suffix))
trimmed := tmp.String()
if ok != test.ok || trimmed != test.trimmed {
t.Errorf("TrimSuffix %+q %+q returned (%+q, %v), expected (%+q, %v)",
test.name, test.suffix, trimmed, ok, test.trimmed, test.ok)
continue
}
}
}
func TestReadName(t *testing.T) {
// Good tests.
for _, test := range []struct {
start int64
end int64
input string
s string
}{
// Empty name.
{0, 1, "\x00abcd", "."},
// No pointers.
{12, 25, "AAAABBBBCCCC\x07example\x03com\x00", "example.com"},
// Backward pointer.
{25, 31, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x0c", "sub.example.com"},
// Forward pointer.
{0, 4, "\x01a\xc0\x04\x03bcd\x00", "a.bcd"},
// Two backwards pointers.
{31, 38, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x0c\x04sub2\xc0\x19", "sub2.sub.example.com"},
// Forward then backward pointer.
{25, 31, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x1f\x04sub2\xc0\x0c", "sub.sub2.example.com"},
// Overlapping codons.
{0, 4, "\x01a\xc0\x03bcd\x00", "a.bcd"},
// Pointer to empty label.
{0, 10, "\x07example\xc0\x0a\x00", "example"},
{1, 11, "\x00\x07example\xc0\x00", "example"},
// Pointer to pointer to empty label.
{0, 10, "\x07example\xc0\x0a\xc0\x0c\x00", "example"},
{1, 11, "\x00\x07example\xc0\x0c\xc0\x00", "example"},
} {
r := bytes.NewReader([]byte(test.input))
_, err := r.Seek(test.start, io.SeekStart)
if err != nil {
panic(err)
}
name, err := readName(r)
if err != nil {
t.Errorf("%+q returned error %s", test.input, err)
continue
}
s := name.String()
if s != test.s {
t.Errorf("%+q returned %+q, expected %+q", test.input, s, test.s)
continue
}
cur, _ := r.Seek(0, io.SeekCurrent)
if cur != test.end {
t.Errorf("%+q left offset %d, expected %d", test.input, cur, test.end)
continue
}
}
// Bad tests.
for _, test := range []struct {
start int64
input string
err error
}{
{0, "", io.ErrUnexpectedEOF},
// Reserved label type.
{0, "\x80example", ErrReservedLabelType},
// Reserved label type.
{0, "\x40example", ErrReservedLabelType},
// No Terminating empty label.
{0, "\x07example\x03com", io.ErrUnexpectedEOF},
// Pointer past end of buffer.
{0, "\x07example\xc0\xff", io.ErrUnexpectedEOF},
// Pointer to self.
{0, "\x07example\x03com\xc0\x0c", ErrTooManyPointers},
// Pointer to self with intermediate label.
{0, "\x07example\x03com\xc0\x08", ErrTooManyPointers},
// Two pointers that point to each other.
{0, "\xc0\x02\xc0\x00", ErrTooManyPointers},
// Two pointers that point to each other, with intermediate labels.
{0, "\x01a\xc0\x04\x01b\xc0\x00", ErrTooManyPointers},
// EOF while reading label.
{0, "\x0aexample", io.ErrUnexpectedEOF},
// EOF before second byte of pointer.
{0, "\xc0", io.ErrUnexpectedEOF},
{0, "\x07example\xc0", io.ErrUnexpectedEOF},
} {
r := bytes.NewReader([]byte(test.input))
_, err := r.Seek(test.start, io.SeekStart)
if err != nil {
panic(err)
}
name, err := readName(r)
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
if err != test.err {
t.Errorf("%+q returned (%+q, %v), expected %v", test.input, name, err, test.err)
continue
}
}
}
func mustParseName(s string) Name {
name, err := ParseName(s)
if err != nil {
panic(err)
}
return name
}
func questionsEqual(a, b *Question) bool {
if !namesEqual(a.Name, b.Name) {
return false
}
if a.Type != b.Type || a.Class != b.Class {
return false
}
return true
}
func rrsEqual(a, b *RR) bool {
if !namesEqual(a.Name, b.Name) {
return false
}
if a.Type != b.Type || a.Class != b.Class || a.TTL != b.TTL {
return false
}
if !bytes.Equal(a.Data, b.Data) {
return false
}
return true
}
func messagesEqual(a, b *Message) bool {
if a.ID != b.ID || a.Flags != b.Flags {
return false
}
if len(a.Question) != len(b.Question) {
return false
}
for i := 0; i < len(a.Question); i++ {
if !questionsEqual(&a.Question[i], &b.Question[i]) {
return false
}
}
for _, rec := range []struct{ rrA, rrB []RR }{
{a.Answer, b.Answer},
{a.Authority, b.Authority},
{a.Additional, b.Additional},
} {
if len(rec.rrA) != len(rec.rrB) {
return false
}
for i := 0; i < len(rec.rrA); i++ {
if !rrsEqual(&rec.rrA[i], &rec.rrB[i]) {
return false
}
}
}
return true
}
func TestMessageFromWireFormat(t *testing.T) {
for _, test := range []struct {
buf string
expected Message
err error
}{
{
"\x12\x34",
Message{},
io.ErrUnexpectedEOF,
},
{
"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01",
Message{
ID: 0x1234,
Flags: 0x0100,
Question: []Question{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
},
},
Answer: []RR{},
Authority: []RR{},
Additional: []RR{},
},
nil,
},
{
"\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01X",
Message{},
ErrTrailingBytes,
},
{
"\x12\x34\x81\x80\x00\x01\x00\x01\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01\x03www\x07example\x03com\x00\x00\x01\x00\x01\x00\x00\x00\x80\x00\x04\xc0\x00\x02\x01",
Message{
ID: 0x1234,
Flags: 0x8180,
Question: []Question{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
},
},
Answer: []RR{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
TTL: 128,
Data: []byte{192, 0, 2, 1},
},
},
Authority: []RR{},
Additional: []RR{},
},
nil,
},
} {
message, err := MessageFromWireFormat([]byte(test.buf))
if err != test.err || (err == nil && !messagesEqual(&message, &test.expected)) {
t.Errorf("%+q\nreturned (%+v, %v)\nexpected (%+v, %v)",
test.buf, message, err, test.expected, test.err)
continue
}
}
}
func TestMessageWireFormatRoundTrip(t *testing.T) {
for _, message := range []Message{
{
ID: 0x1234,
Flags: 0x0100,
Question: []Question{
{
Name: mustParseName("www.example.com"),
Type: 1,
Class: 1,
},
{
Name: mustParseName("www2.example.com"),
Type: 2,
Class: 2,
},
},
Answer: []RR{
{
Name: mustParseName("abc"),
Type: 2,
Class: 3,
TTL: 0xffffffff,
Data: []byte{1},
},
{
Name: mustParseName("xyz"),
Type: 2,
Class: 3,
TTL: 255,
Data: []byte{},
},
},
Authority: []RR{
{
Name: mustParseName("."),
Type: 65535,
Class: 65535,
TTL: 0,
Data: []byte("XXXXXXXXXXXXXXXXXXX"),
},
},
Additional: []RR{},
},
} {
buf, err := message.WireFormat()
if err != nil {
t.Errorf("%+v cannot make wire format: %v", message, err)
continue
}
message2, err := MessageFromWireFormat(buf)
if err != nil {
t.Errorf("%+q cannot parse wire format: %v", buf, err)
continue
}
if !messagesEqual(&message, &message2) {
t.Errorf("messages unequal\nbefore: %+v\n after: %+v", message, message2)
continue
}
}
}
func TestDecodeRDataTXT(t *testing.T) {
for _, test := range []struct {
p []byte
decoded []byte
err error
}{
{[]byte{}, nil, io.ErrUnexpectedEOF},
{[]byte("\x00"), []byte{}, nil},
{[]byte("\x01"), nil, io.ErrUnexpectedEOF},
} {
decoded, err := DecodeRDataTXT(test.p)
if err != test.err || (err == nil && !bytes.Equal(decoded, test.decoded)) {
t.Errorf("%+q\nreturned (%+q, %v)\nexpected (%+q, %v)",
test.p, decoded, err, test.decoded, test.err)
continue
}
}
}
func TestEncodeRDataTXT(t *testing.T) {
// Encoding 0 bytes needs to return at least a single length octet of
// zero, not an empty slice.
p := make([]byte, 0)
encoded := EncodeRDataTXT(p)
if len(encoded) < 0 {
t.Errorf("EncodeRDataTXT(%v) returned %v", p, encoded)
}
// 255 bytes should be able to be encoded into 256 bytes.
p = make([]byte, 255)
encoded = EncodeRDataTXT(p)
if len(encoded) > 256 {
t.Errorf("EncodeRDataTXT(%d bytes) returned %d bytes", len(p), len(encoded))
}
}
func TestRDataTXTRoundTrip(t *testing.T) {
for _, p := range [][]byte{
{},
[]byte("\x00"),
{
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f,
0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f,
0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f,
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f,
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
},
} {
rdata := EncodeRDataTXT(p)
decoded, err := DecodeRDataTXT(rdata)
if err != nil || !bytes.Equal(decoded, p) {
t.Errorf("%+q returned (%+q, %v)", p, decoded, err)
continue
}
}
}

View File

@@ -0,0 +1,567 @@
package xdns
import (
"bytes"
"context"
"encoding/binary"
"io"
"net"
"sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
)
const (
idleTimeout = 2 * time.Minute
responseTTL = 60
maxResponseDelay = 1 * time.Second
)
var (
maxUDPPayload = 1280 - 40 - 8
maxEncodedPayload = computeMaxEncodedPayload(maxUDPPayload)
)
func clientIDToAddr(clientID [8]byte) *net.UDPAddr {
ip := make(net.IP, 16)
copy(ip, []byte{0xfd, 0x00, 0, 0, 0, 0, 0, 0})
copy(ip[8:], clientID[:])
return &net.UDPAddr{
IP: ip,
}
}
type record struct {
Resp *Message
Addr net.Addr
// ClientID [8]byte
ClientAddr net.Addr
}
type queue struct {
lash time.Time
queue chan []byte
stash chan []byte
}
type xdnsConnServer struct {
conn net.PacketConn
domain Name
ch chan *record
readQueue chan *packet
writeQueueMap map[string]*queue
closed bool
mutex sync.Mutex
}
func NewConnServer(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) {
if !end {
return nil, errors.New("xdns requires being at the outermost level")
}
domain, err := ParseName(c.Domain)
if err != nil {
return nil, err
}
conn := &xdnsConnServer{
conn: raw,
domain: domain,
ch: make(chan *record, 100),
readQueue: make(chan *packet, 128),
writeQueueMap: make(map[string]*queue),
}
go conn.clean()
go conn.recvLoop()
go conn.sendLoop()
return conn, nil
}
func (c *xdnsConnServer) clean() {
f := func() bool {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return true
}
now := time.Now()
for key, q := range c.writeQueueMap {
if now.Sub(q.lash) >= idleTimeout {
close(q.queue)
close(q.stash)
delete(c.writeQueueMap, key)
}
}
return false
}
for {
time.Sleep(idleTimeout / 2)
if f() {
return
}
}
}
func (c *xdnsConnServer) ensureQueue(addr net.Addr) *queue {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return nil
}
q, ok := c.writeQueueMap[addr.String()]
if !ok {
q = &queue{
queue: make(chan []byte, 128),
stash: make(chan []byte, 1),
}
c.writeQueueMap[addr.String()] = q
}
q.lash = time.Now()
return q
}
func (c *xdnsConnServer) stash(queue *queue, p []byte) {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return
}
select {
case queue.stash <- p:
default:
}
}
func (c *xdnsConnServer) recvLoop() {
for {
if c.closed {
break
}
var buf [4096]byte
n, addr, err := c.conn.ReadFrom(buf[:])
if err != nil {
continue
}
query, err := MessageFromWireFormat(buf[:n])
if err != nil {
continue
}
resp, payload := responseFor(&query, c.domain)
var clientID [8]byte
n = copy(clientID[:], payload)
payload = payload[n:]
if n == len(clientID) {
r := bytes.NewReader(payload)
for {
p, err := nextPacketServer(r)
if err != nil {
break
}
buf := make([]byte, len(p))
copy(buf, p)
select {
case c.readQueue <- &packet{
p: buf,
addr: clientIDToAddr(clientID),
}:
default:
}
}
} else {
if resp != nil && resp.Rcode() == RcodeNoError {
resp.Flags |= RcodeNameError
}
}
if resp != nil {
select {
case c.ch <- &record{resp, addr, clientIDToAddr(clientID)}:
default:
}
}
}
close(c.ch)
close(c.readQueue)
}
func (c *xdnsConnServer) sendLoop() {
var nextRec *record
for {
rec := nextRec
nextRec = nil
if rec == nil {
var ok bool
rec, ok = <-c.ch
if !ok {
break
}
}
if rec.Resp.Rcode() == RcodeNoError && len(rec.Resp.Question) == 1 {
rec.Resp.Answer = []RR{
{
Name: rec.Resp.Question[0].Name,
Type: rec.Resp.Question[0].Type,
Class: rec.Resp.Question[0].Class,
TTL: responseTTL,
Data: nil,
},
}
var payload bytes.Buffer
limit := maxEncodedPayload
timer := time.NewTimer(maxResponseDelay)
for {
queue := c.ensureQueue(rec.ClientAddr)
if queue == nil {
return
}
var p []byte
select {
case p = <-queue.stash:
default:
select {
case p = <-queue.stash:
case p = <-queue.queue:
default:
select {
case p = <-queue.stash:
case p = <-queue.queue:
case <-timer.C:
case nextRec = <-c.ch:
}
}
}
timer.Reset(0)
if len(p) == 0 {
break
}
limit -= 2 + len(p)
if payload.Len() == 0 {
} else if limit < 0 {
c.stash(queue, p)
break
}
if int(uint16(len(p))) != len(p) {
panic(len(p))
}
_ = binary.Write(&payload, binary.BigEndian, uint16(len(p)))
payload.Write(p)
}
timer.Stop()
rec.Resp.Answer[0].Data = EncodeRDataTXT(payload.Bytes())
}
buf, err := rec.Resp.WireFormat()
if err != nil {
continue
}
if len(buf) > maxUDPPayload {
errors.LogDebug(context.Background(), "xdns server truncate ", len(buf))
buf = buf[:maxUDPPayload]
buf[2] |= 0x02
}
if c.closed {
return
}
_, _ = c.conn.WriteTo(buf, rec.Addr)
}
}
func (c *xdnsConnServer) Size() int32 {
return 0
}
func (c *xdnsConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
packet, ok := <-c.readQueue
if !ok {
return 0, nil, io.EOF
}
n = copy(p, packet.p)
if n != len(packet.p) {
return 0, nil, io.ErrShortBuffer
}
return n, packet.addr, nil
}
func (c *xdnsConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) {
q := c.ensureQueue(addr)
if q == nil {
return 0, errors.New("xdns closed")
}
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return 0, errors.New("xdns closed")
}
buf := make([]byte, len(p))
copy(buf, p)
select {
case q.queue <- buf:
return len(p), nil
default:
return 0, errors.New("xdns queue full")
}
}
func (c *xdnsConnServer) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return nil
}
c.closed = true
for key, q := range c.writeQueueMap {
close(q.queue)
close(q.stash)
delete(c.writeQueueMap, key)
}
return c.conn.Close()
}
func (c *xdnsConnServer) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *xdnsConnServer) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
func (c *xdnsConnServer) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *xdnsConnServer) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
func nextPacketServer(r *bytes.Reader) ([]byte, error) {
eof := func(err error) error {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return err
}
for {
prefix, err := r.ReadByte()
if err != nil {
return nil, err
}
if prefix >= 224 {
paddingLen := prefix - 224
_, err := io.CopyN(io.Discard, r, int64(paddingLen))
if err != nil {
return nil, eof(err)
}
} else {
p := make([]byte, int(prefix))
_, err = io.ReadFull(r, p)
return p, eof(err)
}
}
}
func responseFor(query *Message, domain Name) (*Message, []byte) {
resp := &Message{
ID: query.ID,
Flags: 0x8000,
Question: query.Question,
}
if query.Flags&0x8000 != 0 {
return nil, nil
}
payloadSize := 0
for _, rr := range query.Additional {
if rr.Type != RRTypeOPT {
continue
}
if len(resp.Additional) != 0 {
resp.Flags |= RcodeFormatError
return resp, nil
}
resp.Additional = append(resp.Additional, RR{
Name: Name{},
Type: RRTypeOPT,
Class: 4096,
TTL: 0,
Data: []byte{},
})
additional := &resp.Additional[0]
version := (rr.TTL >> 16) & 0xff
if version != 0 {
resp.Flags |= ExtendedRcodeBadVers & 0xf
additional.TTL = (ExtendedRcodeBadVers >> 4) << 24
return resp, nil
}
payloadSize = int(rr.Class)
}
if payloadSize < 512 {
payloadSize = 512
}
if len(query.Question) != 1 {
resp.Flags |= RcodeFormatError
return resp, nil
}
question := query.Question[0]
prefix, ok := question.Name.TrimSuffix(domain)
if !ok {
resp.Flags |= RcodeNameError
return resp, nil
}
resp.Flags |= 0x0400
if query.Opcode() != 0 {
resp.Flags |= RcodeNotImplemented
return resp, nil
}
if question.Type != RRTypeTXT {
resp.Flags |= RcodeNameError
return resp, nil
}
encoded := bytes.ToUpper(bytes.Join(prefix, nil))
payload := make([]byte, base32Encoding.DecodedLen(len(encoded)))
n, err := base32Encoding.Decode(payload, encoded)
if err != nil {
resp.Flags |= RcodeNameError
return resp, nil
}
payload = payload[:n]
if payloadSize < maxUDPPayload {
resp.Flags |= RcodeFormatError
return resp, nil
}
return resp, payload
}
func computeMaxEncodedPayload(limit int) int {
maxLengthName, err := NewName([][]byte{
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
[]byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
})
if err != nil {
panic(err)
}
{
n := 0
for _, label := range maxLengthName {
n += len(label) + 1
}
n += 1
if n != 255 {
panic("computeMaxEncodedPayload n != 255")
}
}
queryLimit := uint16(limit)
if int(queryLimit) != limit {
queryLimit = 0xffff
}
query := &Message{
Question: []Question{
{
Name: maxLengthName,
Type: RRTypeTXT,
Class: RRTypeTXT,
},
},
Additional: []RR{
{
Name: Name{},
Type: RRTypeOPT,
Class: queryLimit,
TTL: 0,
Data: []byte{},
},
},
}
resp, _ := responseFor(query, [][]byte{})
resp.Answer = []RR{
{
Name: query.Question[0].Name,
Type: query.Question[0].Type,
Class: query.Question[0].Class,
TTL: responseTTL,
Data: nil,
},
}
low := 0
high := 32768
for low+1 < high {
mid := (low + high) / 2
resp.Answer[0].Data = EncodeRDataTXT(make([]byte, mid))
buf, err := resp.WireFormat()
if err != nil {
panic(err)
}
if len(buf) <= limit {
low = mid
} else {
high = mid
}
}
return low
}

View File

@@ -0,0 +1,350 @@
package xicmp
import (
"context"
"io"
"net"
"strings"
"sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/crypto"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
const (
initPollDelay = 500 * time.Millisecond
maxPollDelay = 10 * time.Second
pollDelayMultiplier = 2.0
pollLimit = 16
windowSize = 1000
)
type packet struct {
p []byte
addr net.Addr
}
type seqStatus struct {
needSeqByte bool
seqByte byte
}
type xicmpConnClient struct {
conn net.PacketConn
icmpConn *icmp.PacketConn
typ icmp.Type
id int
seq int
proto int
seqStatus map[int]*seqStatus
pollChan chan struct{}
readQueue chan *packet
writeQueue chan *packet
closed bool
mutex sync.Mutex
}
func NewConnClient(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) {
if !end {
return nil, errors.New("xicmp requires being at the outermost level")
}
network := "ip4:icmp"
typ := icmp.Type(ipv4.ICMPTypeEcho)
proto := 1
if strings.Contains(c.Ip, ":") {
network = "ip6:ipv6-icmp"
typ = ipv6.ICMPTypeEchoRequest
proto = 58
}
icmpConn, err := icmp.ListenPacket(network, c.Ip)
if err != nil {
return nil, errors.New("xicmp listen err").Base(err)
}
if c.Id == 0 {
c.Id = int32(crypto.RandBetween(0, 65535))
}
conn := &xicmpConnClient{
conn: raw,
icmpConn: icmpConn,
typ: typ,
id: int(c.Id),
seq: 1,
proto: proto,
seqStatus: make(map[int]*seqStatus),
pollChan: make(chan struct{}, pollLimit),
readQueue: make(chan *packet, 128),
writeQueue: make(chan *packet, 128),
}
go conn.recvLoop()
go conn.sendLoop()
return conn, nil
}
func (c *xicmpConnClient) encode(p []byte) ([]byte, error) {
c.mutex.Lock()
defer c.mutex.Unlock()
needSeqByte := false
var seqByte byte
data := p
if len(p) > 0 {
needSeqByte = true
seqByte = p[0]
}
msg := icmp.Message{
Type: c.typ,
Code: 0,
Body: &icmp.Echo{
ID: c.id,
Seq: c.seq,
Data: data,
},
}
buf, err := msg.Marshal(nil)
if err != nil {
return nil, err
}
if len(buf) > 8192 {
return nil, errors.New("xicmp len(buf) > 8192")
}
c.seqStatus[c.seq] = &seqStatus{
needSeqByte: needSeqByte,
seqByte: seqByte,
}
delete(c.seqStatus, int(uint16(c.seq-windowSize)))
c.seq++
if c.seq == 65536 {
delete(c.seqStatus, int(uint16(c.seq-windowSize)))
c.seq = 1
}
return buf, nil
}
func (c *xicmpConnClient) recvLoop() {
for {
if c.closed {
break
}
var buf [8192]byte
n, addr, err := c.icmpConn.ReadFrom(buf[:])
if err != nil {
continue
}
msg, err := icmp.ParseMessage(c.proto, buf[:n])
if err != nil {
continue
}
if msg.Type != ipv4.ICMPTypeEchoReply && msg.Type != ipv6.ICMPTypeEchoReply {
continue
}
echo, ok := msg.Body.(*icmp.Echo)
if !ok {
continue
}
c.mutex.Lock()
seqStatus, ok := c.seqStatus[echo.Seq]
c.mutex.Unlock()
if !ok {
continue
}
if seqStatus.needSeqByte {
if len(echo.Data) <= 1 {
continue
}
if echo.Data[0] == seqStatus.seqByte {
continue
}
echo.Data = echo.Data[1:]
}
if len(echo.Data) > 0 {
c.mutex.Lock()
delete(c.seqStatus, echo.Seq)
c.mutex.Unlock()
buf := make([]byte, len(echo.Data))
copy(buf, echo.Data)
select {
case c.readQueue <- &packet{
p: buf,
addr: &net.UDPAddr{IP: addr.(*net.IPAddr).IP},
}:
default:
}
select {
case c.pollChan <- struct{}{}:
default:
}
}
}
close(c.pollChan)
close(c.readQueue)
}
func (c *xicmpConnClient) sendLoop() {
var addr net.Addr
pollDelay := initPollDelay
pollTimer := time.NewTimer(pollDelay)
for {
var p *packet
pollTimerExpired := false
select {
case p = <-c.writeQueue:
default:
select {
case p = <-c.writeQueue:
case <-c.pollChan:
case <-pollTimer.C:
pollTimerExpired = true
}
}
if p != nil {
addr = p.addr
select {
case <-c.pollChan:
default:
}
} else if addr != nil {
encoded, _ := c.encode(nil)
p = &packet{
p: encoded,
addr: addr,
}
}
if pollTimerExpired {
pollDelay = time.Duration(float64(pollDelay) * pollDelayMultiplier)
if pollDelay > maxPollDelay {
pollDelay = maxPollDelay
}
} else {
if !pollTimer.Stop() {
<-pollTimer.C
}
pollDelay = initPollDelay
}
pollTimer.Reset(pollDelay)
if c.closed {
return
}
if p != nil {
_, err := c.icmpConn.WriteTo(p.p, p.addr)
if err != nil {
errors.LogDebug(context.Background(), "xicmp writeto err ", err)
}
}
}
}
func (c *xicmpConnClient) Size() int32 {
return 0
}
func (c *xicmpConnClient) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
packet, ok := <-c.readQueue
if !ok {
return 0, nil, io.EOF
}
n = copy(p, packet.p)
if n != len(packet.p) {
return 0, nil, io.ErrShortBuffer
}
return n, packet.addr, nil
}
func (c *xicmpConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) {
encoded, err := c.encode(p)
if err != nil {
return 0, errors.New("xicmp encode").Base(err)
}
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return 0, errors.New("xicmp closed")
}
select {
case c.writeQueue <- &packet{
p: encoded,
addr: &net.IPAddr{IP: addr.(*net.UDPAddr).IP},
}:
return len(p), nil
default:
return 0, errors.New("xicmp queue full")
}
}
func (c *xicmpConnClient) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return nil
}
c.closed = true
close(c.writeQueue)
_ = c.icmpConn.Close()
return c.conn.Close()
}
func (c *xicmpConnClient) LocalAddr() net.Addr {
return &net.UDPAddr{
IP: c.icmpConn.LocalAddr().(*net.IPAddr).IP,
Port: c.id,
}
}
func (c *xicmpConnClient) SetDeadline(t time.Time) error {
return c.icmpConn.SetDeadline(t)
}
func (c *xicmpConnClient) SetReadDeadline(t time.Time) error {
return c.icmpConn.SetReadDeadline(t)
}
func (c *xicmpConnClient) SetWriteDeadline(t time.Time) error {
return c.icmpConn.SetWriteDeadline(t)
}

View File

@@ -0,0 +1,16 @@
package xicmp
import (
"net"
)
func (c *Config) UDP() {
}
func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnClient(c, raw, end)
}
func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) {
return NewConnServer(c, raw, end)
}

View File

@@ -0,0 +1,132 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.11
// protoc v3.21.12
// source: transport/internet/finalmask/xicmp/config.proto
package xicmp
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Config struct {
state protoimpl.MessageState `protogen:"open.v1"`
Ip string `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"`
Id int32 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Config) Reset() {
*x = Config{}
mi := &file_transport_internet_finalmask_xicmp_config_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Config) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Config) ProtoMessage() {}
func (x *Config) ProtoReflect() protoreflect.Message {
mi := &file_transport_internet_finalmask_xicmp_config_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Config.ProtoReflect.Descriptor instead.
func (*Config) Descriptor() ([]byte, []int) {
return file_transport_internet_finalmask_xicmp_config_proto_rawDescGZIP(), []int{0}
}
func (x *Config) GetIp() string {
if x != nil {
return x.Ip
}
return ""
}
func (x *Config) GetId() int32 {
if x != nil {
return x.Id
}
return 0
}
var File_transport_internet_finalmask_xicmp_config_proto protoreflect.FileDescriptor
const file_transport_internet_finalmask_xicmp_config_proto_rawDesc = "" +
"\n" +
"/transport/internet/finalmask/xicmp/config.proto\x12'xray.transport.internet.finalmask.xicmp\"(\n" +
"\x06Config\x12\x0e\n" +
"\x02ip\x18\x01 \x01(\tR\x02ip\x12\x0e\n" +
"\x02id\x18\x02 \x01(\x05R\x02idB\xa6\x01\n" +
"+com.xray.transport.internet.finalmask.xicmpP\x01ZKgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xicmp\xaa\x02'Xray.Transport.Internet.Finalmask.Xicmpb\x06proto3"
var (
file_transport_internet_finalmask_xicmp_config_proto_rawDescOnce sync.Once
file_transport_internet_finalmask_xicmp_config_proto_rawDescData []byte
)
func file_transport_internet_finalmask_xicmp_config_proto_rawDescGZIP() []byte {
file_transport_internet_finalmask_xicmp_config_proto_rawDescOnce.Do(func() {
file_transport_internet_finalmask_xicmp_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_xicmp_config_proto_rawDesc), len(file_transport_internet_finalmask_xicmp_config_proto_rawDesc)))
})
return file_transport_internet_finalmask_xicmp_config_proto_rawDescData
}
var file_transport_internet_finalmask_xicmp_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_transport_internet_finalmask_xicmp_config_proto_goTypes = []any{
(*Config)(nil), // 0: xray.transport.internet.finalmask.xicmp.Config
}
var file_transport_internet_finalmask_xicmp_config_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_transport_internet_finalmask_xicmp_config_proto_init() }
func file_transport_internet_finalmask_xicmp_config_proto_init() {
if File_transport_internet_finalmask_xicmp_config_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_xicmp_config_proto_rawDesc), len(file_transport_internet_finalmask_xicmp_config_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_transport_internet_finalmask_xicmp_config_proto_goTypes,
DependencyIndexes: file_transport_internet_finalmask_xicmp_config_proto_depIdxs,
MessageInfos: file_transport_internet_finalmask_xicmp_config_proto_msgTypes,
}.Build()
File_transport_internet_finalmask_xicmp_config_proto = out.File
file_transport_internet_finalmask_xicmp_config_proto_goTypes = nil
file_transport_internet_finalmask_xicmp_config_proto_depIdxs = nil
}

View File

@@ -0,0 +1,13 @@
syntax = "proto3";
package xray.transport.internet.finalmask.xicmp;
option csharp_namespace = "Xray.Transport.Internet.Finalmask.Xicmp";
option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xicmp";
option java_package = "com.xray.transport.internet.finalmask.xicmp";
option java_multiple_files = true;
message Config {
string ip = 1;
int32 id = 2;
}

View File

@@ -0,0 +1,377 @@
package xicmp
import (
"context"
"io"
"net"
"strings"
"sync"
"time"
"github.com/amnezia-vpn/amnezia-xray-core/common/crypto"
"github.com/amnezia-vpn/amnezia-xray-core/common/errors"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
const (
idleTimeout = 2 * time.Minute
maxResponseDelay = 1 * time.Second
)
type record struct {
id int
seq int
needSeqByte bool
seqByte byte
addr net.Addr
}
type queue struct {
lash time.Time
queue chan []byte
}
type xicmpConnServer struct {
conn net.PacketConn
icmpConn *icmp.PacketConn
typ icmp.Type
proto int
config *Config
ch chan *record
readQueue chan *packet
writeQueueMap map[string]*queue
closed bool
mutex sync.Mutex
}
func NewConnServer(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) {
if !end {
return nil, errors.New("xicmp requires being at the outermost level")
}
network := "ip4:icmp"
typ := icmp.Type(ipv4.ICMPTypeEchoReply)
proto := 1
if strings.Contains(c.Ip, ":") {
network = "ip6:ipv6-icmp"
typ = ipv6.ICMPTypeEchoReply
proto = 58
}
icmpConn, err := icmp.ListenPacket(network, c.Ip)
if err != nil {
return nil, errors.New("xicmp listen err").Base(err)
}
conn := &xicmpConnServer{
conn: raw,
icmpConn: icmpConn,
typ: typ,
proto: proto,
config: c,
ch: make(chan *record, 100),
readQueue: make(chan *packet, 128),
writeQueueMap: make(map[string]*queue),
}
go conn.clean()
go conn.recvLoop()
go conn.sendLoop()
return conn, nil
}
func (c *xicmpConnServer) clean() {
f := func() bool {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return true
}
now := time.Now()
for key, q := range c.writeQueueMap {
if now.Sub(q.lash) >= idleTimeout {
close(q.queue)
delete(c.writeQueueMap, key)
}
}
return false
}
for {
time.Sleep(idleTimeout / 2)
if f() {
return
}
}
}
func (c *xicmpConnServer) ensureQueue(addr net.Addr) *queue {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return nil
}
q, ok := c.writeQueueMap[addr.String()]
if !ok {
q = &queue{
queue: make(chan []byte, 128),
}
c.writeQueueMap[addr.String()] = q
}
q.lash = time.Now()
return q
}
func (c *xicmpConnServer) encode(p []byte, id int, seq int, needSeqByte bool, seqByte byte) ([]byte, error) {
data := p
if needSeqByte {
b2 := c.randUntil(seqByte)
data = append([]byte{b2}, p...)
}
msg := icmp.Message{
Type: c.typ,
Code: 0,
Body: &icmp.Echo{
ID: id,
Seq: seq,
Data: data,
},
}
buf, err := msg.Marshal(nil)
if err != nil {
return nil, err
}
if len(buf) > 8192 {
return nil, errors.New("xicmp len(buf) > 8192")
}
return buf, nil
}
func (c *xicmpConnServer) randUntil(b1 byte) byte {
b2 := byte(crypto.RandBetween(0, 255))
for {
if b2 != b1 {
return b2
}
b2 = byte(crypto.RandBetween(0, 255))
}
}
func (c *xicmpConnServer) recvLoop() {
for {
if c.closed {
break
}
var buf [8192]byte
n, addr, err := c.icmpConn.ReadFrom(buf[:])
if err != nil {
continue
}
msg, err := icmp.ParseMessage(c.proto, buf[:n])
if err != nil {
continue
}
if msg.Type != ipv4.ICMPTypeEcho && msg.Type != ipv6.ICMPTypeEchoRequest {
continue
}
echo, ok := msg.Body.(*icmp.Echo)
if !ok {
continue
}
if c.config.Id != 0 && echo.ID != int(c.config.Id) {
continue
}
needSeqByte := false
var seqByte byte
if len(echo.Data) > 0 {
needSeqByte = true
seqByte = echo.Data[0]
buf := make([]byte, len(echo.Data))
copy(buf, echo.Data)
select {
case c.readQueue <- &packet{
p: buf,
addr: &net.UDPAddr{
IP: addr.(*net.IPAddr).IP,
Port: echo.ID,
},
}:
default:
}
}
select {
case c.ch <- &record{
id: echo.ID,
seq: echo.Seq,
needSeqByte: needSeqByte,
seqByte: seqByte,
addr: &net.UDPAddr{
IP: addr.(*net.IPAddr).IP,
Port: echo.ID,
},
}:
default:
}
}
close(c.ch)
close(c.readQueue)
}
func (c *xicmpConnServer) sendLoop() {
var nextRec *record
for {
rec := nextRec
nextRec = nil
if rec == nil {
var ok bool
rec, ok = <-c.ch
if !ok {
break
}
}
queue := c.ensureQueue(rec.addr)
if queue == nil {
return
}
var p []byte
timer := time.NewTimer(maxResponseDelay)
select {
case p = <-queue.queue:
default:
select {
case p = <-queue.queue:
case <-timer.C:
case nextRec = <-c.ch:
}
}
timer.Stop()
if len(p) == 0 {
continue
}
buf, err := c.encode(p, rec.id, rec.seq, rec.needSeqByte, rec.seqByte)
if err != nil {
continue
}
if c.closed {
return
}
_, err = c.icmpConn.WriteTo(buf, &net.IPAddr{IP: rec.addr.(*net.UDPAddr).IP})
if err != nil {
errors.LogDebug(context.Background(), "xicmp writeto err ", err)
}
}
}
func (c *xicmpConnServer) Size() int32 {
return 0
}
func (c *xicmpConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
packet, ok := <-c.readQueue
if !ok {
return 0, nil, io.EOF
}
n = copy(p, packet.p)
if n != len(packet.p) {
return 0, nil, io.ErrShortBuffer
}
return n, packet.addr, nil
}
func (c *xicmpConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) {
q := c.ensureQueue(addr)
if q == nil {
return 0, errors.New("xicmp closed")
}
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return 0, errors.New("xicmp closed")
}
buf := make([]byte, len(p))
copy(buf, p)
select {
case q.queue <- buf:
return len(p), nil
default:
return 0, errors.New("xicmp queue full")
}
}
func (c *xicmpConnServer) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.closed {
return nil
}
c.closed = true
for key, q := range c.writeQueueMap {
close(q.queue)
delete(c.writeQueueMap, key)
}
_ = c.icmpConn.Close()
return c.conn.Close()
}
func (c *xicmpConnServer) LocalAddr() net.Addr {
return &net.UDPAddr{IP: c.icmpConn.LocalAddr().(*net.IPAddr).IP}
}
func (c *xicmpConnServer) SetDeadline(t time.Time) error {
return c.icmpConn.SetDeadline(t)
}
func (c *xicmpConnServer) SetReadDeadline(t time.Time) error {
return c.icmpConn.SetReadDeadline(t)
}
func (c *xicmpConnServer) SetWriteDeadline(t time.Time) error {
return c.icmpConn.SetWriteDeadline(t)
}

View File

@@ -0,0 +1,74 @@
package xicmp_test
import (
"bytes"
"fmt"
"testing"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
func TestICMPEchoMarshal(t *testing.T) {
msg := icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &icmp.Echo{
ID: 65535,
Seq: 65537,
Data: nil,
},
}
ICMPTypeEcho, _ := msg.Marshal(nil)
fmt.Println("ICMPTypeEcho", len(ICMPTypeEcho), ICMPTypeEcho)
msg = icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Code: 0,
Body: &icmp.Echo{
ID: 65535,
Seq: 65537,
Data: nil,
},
}
ICMPTypeEchoReply, _ := msg.Marshal(nil)
fmt.Println("ICMPTypeEchoReply", len(ICMPTypeEchoReply), ICMPTypeEchoReply)
msg = icmp.Message{
Type: ipv6.ICMPTypeEchoRequest,
Code: 0,
Body: &icmp.Echo{
ID: 65535,
Seq: 65537,
Data: nil,
},
}
ICMPTypeEchoRequest, _ := msg.Marshal(nil)
fmt.Println("ICMPTypeEchoRequest", len(ICMPTypeEchoRequest), ICMPTypeEchoRequest)
msg = icmp.Message{
Type: ipv6.ICMPTypeEchoReply,
Code: 0,
Body: &icmp.Echo{
ID: 65535,
Seq: 65537,
Data: nil,
},
}
V6ICMPTypeEchoReply, _ := msg.Marshal(nil)
fmt.Println("V6ICMPTypeEchoReply", len(V6ICMPTypeEchoReply), V6ICMPTypeEchoReply)
if !bytes.Equal(ICMPTypeEcho[0:2], []byte{8, 0}) || !bytes.Equal(ICMPTypeEcho[4:], []byte{255, 255, 0, 1}) {
t.Fatalf("ICMPTypeEcho Type/Code or ID/Seq mismatch: %v", ICMPTypeEcho)
}
if !bytes.Equal(ICMPTypeEchoReply[0:2], []byte{0, 0}) || !bytes.Equal(ICMPTypeEchoReply[4:], []byte{255, 255, 0, 1}) {
t.Fatalf("ICMPTypeEchoReply Type/Code or ID/Seq mismatch: %v", ICMPTypeEchoReply)
}
if !bytes.Equal(ICMPTypeEchoRequest[0:2], []byte{128, 0}) || !bytes.Equal(ICMPTypeEchoRequest[4:], []byte{255, 255, 0, 1}) {
t.Fatalf("ICMPTypeEchoRequest Type/Code or ID/Seq mismatch: %v", ICMPTypeEchoRequest)
}
if !bytes.Equal(V6ICMPTypeEchoReply[0:2], []byte{129, 0}) || !bytes.Equal(V6ICMPTypeEchoReply[4:], []byte{255, 255, 0, 1}) {
t.Fatalf("V6ICMPTypeEchoReply Type/Code or ID/Seq mismatch: %v", V6ICMPTypeEchoReply)
}
}

Some files were not shown because too many files have changed in this diff Show More