diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 036b7b9c..c07983a4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,6 +34,22 @@ jobs: if: steps.check-assets.outputs.missing == 'true' run: sleep 90 + check-proto: + runs-on: ubuntu-latest + steps: + - name: Checkout codebase + uses: actions/checkout@v6 + - name: Check Proto Version Header + run: | + head -n 4 core/config.pb.go > ref.txt + find . -name "*.pb.go" ! -name "*_grpc.pb.go" -print0 | while IFS= read -r -d '' file; do + if ! cmp -s ref.txt <(head -n 4 "$file"); then + echo "Error: Header mismatch in $file" + head -n 4 "$file" + exit 1 + fi + done + test: needs: check-assets permissions: diff --git a/app/dns/dns.go b/app/dns/dns.go index 91eac480..e47ae54d 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -12,9 +12,11 @@ import ( "sync" "time" + "github.com/amnezia-vpn/amnezia-xray-core/app/router" "github.com/amnezia-vpn/amnezia-xray-core/common" "github.com/amnezia-vpn/amnezia-xray-core/common/errors" "github.com/amnezia-vpn/amnezia-xray-core/common/net" + "github.com/amnezia-vpn/amnezia-xray-core/common/platform" "github.com/amnezia-vpn/amnezia-xray-core/common/session" "github.com/amnezia-vpn/amnezia-xray-core/common/strmatcher" "github.com/amnezia-vpn/amnezia-xray-core/features/dns" @@ -83,9 +85,31 @@ func New(ctx context.Context, config *Config) (*DNS, error) { return nil, errors.New("unexpected query strategy ", config.QueryStrategy) } - hosts, err := NewStaticHosts(config.StaticHosts) - if err != nil { - return nil, errors.New("failed to create hosts").Base(err) + var hosts *StaticHosts + mphLoaded := false + domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) + if domainMatcherPath != "" { + if f, err := os.Open(domainMatcherPath); err == nil { + defer f.Close() + if m, err := router.LoadGeoSiteMatcher(f, "HOSTS"); err == nil { + f.Seek(0, 0) + if hostIPs, err := router.LoadGeoSiteHosts(f); err == nil { + if sh, err := NewStaticHostsFromCache(m, hostIPs); err == nil { + hosts = sh + mphLoaded = true + errors.LogDebug(ctx, "MphDomainMatcher loaded from cache for DNS hosts, size: ", sh.matchers.Size()) + } + } + } + } + } + + if !mphLoaded { + sh, err := NewStaticHosts(config.StaticHosts) + if err != nil { + return nil, errors.New("failed to create hosts").Base(err) + } + hosts = sh } var clients []*Client diff --git a/app/dns/hosts.go b/app/dns/hosts.go index d287df24..c0e9c101 100644 --- a/app/dns/hosts.go +++ b/app/dns/hosts.go @@ -14,7 +14,7 @@ import ( // StaticHosts represents static domain-ip mapping in DNS server. type StaticHosts struct { ips [][]net.Address - matchers *strmatcher.MatcherGroup + matchers strmatcher.IndexMatcher } // NewStaticHosts creates a new StaticHosts instance. @@ -124,3 +124,50 @@ func (h *StaticHosts) lookup(domain string, option dns.IPOption, maxDepth int) ( func (h *StaticHosts) Lookup(domain string, option dns.IPOption) ([]net.Address, error) { return h.lookup(domain, option, 5) } +func NewStaticHostsFromCache(matcher strmatcher.IndexMatcher, hostIPs map[string][]string) (*StaticHosts, error) { + sh := &StaticHosts{ + ips: make([][]net.Address, matcher.Size()+1), + matchers: matcher, + } + + order := hostIPs["_ORDER"] + var offset uint32 + + img, ok := matcher.(*strmatcher.IndexMatcherGroup) + if !ok { + // Single matcher (e.g. only manual or only one geosite) + if len(order) > 0 { + pattern := order[0] + ips := parseIPs(hostIPs[pattern]) + for i := uint32(1); i <= matcher.Size(); i++ { + sh.ips[i] = ips + } + } + return sh, nil + } + + for i, m := range img.Matchers { + if i < len(order) { + pattern := order[i] + ips := parseIPs(hostIPs[pattern]) + for j := uint32(1); j <= m.Size(); j++ { + sh.ips[offset+j] = ips + } + offset += m.Size() + } + } + return sh, nil +} + +func parseIPs(raw []string) []net.Address { + addrs := make([]net.Address, 0, len(raw)) + for _, s := range raw { + if len(s) > 1 && s[0] == '#' { + rcode, _ := strconv.Atoi(s[1:]) + addrs = append(addrs, dns.RCodeError(rcode)) + } else { + addrs = append(addrs, net.ParseAddress(s)) + } + } + return addrs +} diff --git a/app/dns/hosts_test.go b/app/dns/hosts_test.go index 21f25db2..4b1ce5d4 100644 --- a/app/dns/hosts_test.go +++ b/app/dns/hosts_test.go @@ -1,9 +1,11 @@ package dns_test import ( + "bytes" "testing" . "github.com/amnezia-vpn/amnezia-xray-core/app/dns" + "github.com/amnezia-vpn/amnezia-xray-core/app/router" "github.com/amnezia-vpn/amnezia-xray-core/common" "github.com/amnezia-vpn/amnezia-xray-core/common/net" "github.com/amnezia-vpn/amnezia-xray-core/features/dns" @@ -130,3 +132,57 @@ func TestStaticHosts(t *testing.T) { } } } +func TestStaticHostsFromCache(t *testing.T) { + sites := []*router.GeoSite{ + { + CountryCode: "cloudflare-dns.com", + Domain: []*router.Domain{ + {Type: router.Domain_Full, Value: "example.com"}, + }, + }, + { + CountryCode: "geosite:cn", + Domain: []*router.Domain{ + {Type: router.Domain_Domain, Value: "baidu.cn"}, + }, + }, + } + deps := map[string][]string{ + "HOSTS": {"cloudflare-dns.com", "geosite:cn"}, + } + hostIPs := map[string][]string{ + "cloudflare-dns.com": {"1.1.1.1"}, + "geosite:cn": {"2.2.2.2"}, + "_ORDER": {"cloudflare-dns.com", "geosite:cn"}, + } + + var buf bytes.Buffer + err := router.SerializeGeoSiteList(sites, deps, hostIPs, &buf) + common.Must(err) + + // Load matcher + m, err := router.LoadGeoSiteMatcher(bytes.NewReader(buf.Bytes()), "HOSTS") + common.Must(err) + + // Load hostIPs + f := bytes.NewReader(buf.Bytes()) + hips, err := router.LoadGeoSiteHosts(f) + common.Must(err) + + hosts, err := NewStaticHostsFromCache(m, hips) + common.Must(err) + + { + ips, _ := hosts.Lookup("example.com", dns.IPOption{IPv4Enable: true}) + if len(ips) != 1 || ips[0].String() != "1.1.1.1" { + t.Error("failed to lookup example.com from cache") + } + } + + { + ips, _ := hosts.Lookup("baidu.cn", dns.IPOption{IPv4Enable: true}) + if len(ips) != 1 || ips[0].String() != "2.2.2.2" { + t.Error("failed to lookup baidu.cn from cache deps") + } + } +} diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index a9badd81..7f27e516 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -10,6 +10,8 @@ import ( "github.com/amnezia-vpn/amnezia-xray-core/app/router" "github.com/amnezia-vpn/amnezia-xray-core/common/errors" "github.com/amnezia-vpn/amnezia-xray-core/common/net" + "github.com/amnezia-vpn/amnezia-xray-core/common/platform" + "github.com/amnezia-vpn/amnezia-xray-core/common/platform/filesystem" "github.com/amnezia-vpn/amnezia-xray-core/common/session" "github.com/amnezia-vpn/amnezia-xray-core/common/strmatcher" "github.com/amnezia-vpn/amnezia-xray-core/core" @@ -17,6 +19,18 @@ import ( "github.com/amnezia-vpn/amnezia-xray-core/features/routing" ) +type mphMatcherWrapper struct { + m strmatcher.IndexMatcher +} + +func (w *mphMatcherWrapper) Match(s string) bool { + return w.m.Match(s) != nil +} + +func (w *mphMatcherWrapper) String() string { + return "mph-matcher" +} + // Server is the interface for Name Server. type Server interface { // Name of the Client. @@ -132,29 +146,50 @@ func NewClient( var rules []string ruleCurr := 0 ruleIter := 0 - for i, domain := range ns.PrioritizedDomain { - ns.PrioritizedDomain[i] = nil - domainRule, err := toStrMatcher(domain.Type, domain.Domain) - if err != nil { - errors.LogErrorInner(ctx, err, "failed to create domain matcher, ignore domain rule [type: ", domain.Type, ", domain: ", domain.Domain, "]") - domainRule, _ = toStrMatcher(DomainMatchingType_Full, "hack.fix.index.for.illegal.domain.rule") - } - originalRuleIdx := ruleCurr - if ruleCurr < len(ns.OriginalRules) { - rule := ns.OriginalRules[ruleCurr] - if ruleCurr >= len(rules) { - rules = append(rules, rule.Rule) + + // Check if domain matcher cache is provided via environment + domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) + var mphLoaded bool + + if domainMatcherPath != "" && ns.Tag != "" { + f, err := filesystem.NewFileReader(domainMatcherPath) + if err == nil { + defer f.Close() + g, err := router.LoadGeoSiteMatcher(f, ns.Tag) + if err == nil { + errors.LogDebug(ctx, "MphDomainMatcher loaded from cache for ", ns.Tag, " dns tag)") + updateDomainRule(&mphMatcherWrapper{m: g}, 0, *matcherInfos) + rules = append(rules, "[MPH Cache]") + mphLoaded = true } - ruleIter++ - if ruleIter >= int(rule.Size) { - ruleIter = 0 + } + } + + if !mphLoaded { + for i, domain := range ns.PrioritizedDomain { + ns.PrioritizedDomain[i] = nil + domainRule, err := toStrMatcher(domain.Type, domain.Domain) + if err != nil { + errors.LogErrorInner(ctx, err, "failed to create domain matcher, ignore domain rule [type: ", domain.Type, ", domain: ", domain.Domain, "]") + domainRule, _ = toStrMatcher(DomainMatchingType_Full, "hack.fix.index.for.illegal.domain.rule") + } + originalRuleIdx := ruleCurr + if ruleCurr < len(ns.OriginalRules) { + rule := ns.OriginalRules[ruleCurr] + if ruleCurr >= len(rules) { + rules = append(rules, rule.Rule) + } + ruleIter++ + if ruleIter >= int(rule.Size) { + ruleIter = 0 + ruleCurr++ + } + } else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests) + rules = append(rules, domainRule.String()) ruleCurr++ } - } else { // No original rule, generate one according to current domain matcher (majorly for compatibility with tests) - rules = append(rules, domainRule.String()) - ruleCurr++ + updateDomainRule(domainRule, originalRuleIdx, *matcherInfos) } - updateDomainRule(domainRule, originalRuleIdx, *matcherInfos) } ns.PrioritizedDomain = nil runtime.GC() diff --git a/app/dns/nameserver_doh.go b/app/dns/nameserver_doh.go index 16b6a210..1780f97e 100644 --- a/app/dns/nameserver_doh.go +++ b/app/dns/nameserver_doh.go @@ -8,7 +8,6 @@ import ( "io" "net/http" "net/url" - "strings" "time" "github.com/amnezia-vpn/amnezia-xray-core/common" @@ -19,6 +18,7 @@ import ( "github.com/amnezia-vpn/amnezia-xray-core/common/net/cnc" "github.com/amnezia-vpn/amnezia-xray-core/common/protocol/dns" "github.com/amnezia-vpn/amnezia-xray-core/common/session" + "github.com/amnezia-vpn/amnezia-xray-core/common/utils" dns_feature "github.com/amnezia-vpn/amnezia-xray-core/features/dns" "github.com/amnezia-vpn/amnezia-xray-core/features/routing" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet" @@ -214,8 +214,8 @@ func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, req.Header.Add("Accept", "application/dns-message") req.Header.Add("Content-Type", "application/dns-message") - - req.Header.Set("X-Padding", strings.Repeat("X", int(crypto.RandBetween(100, 1000)))) + req.Header.Set("User-Agent", utils.ChromeUA) + req.Header.Set("X-Padding", utils.H2Base62Pad(crypto.RandBetween(100, 1000))) hc := s.httpClient diff --git a/app/log/command/config_grpc.pb.go b/app/log/command/config_grpc.pb.go index 492e998d..18cdee57 100644 --- a/app/log/command/config_grpc.pb.go +++ b/app/log/command/config_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v5.28.2 +// - protoc-gen-go-grpc v1.6.1 +// - protoc v3.21.12 // source: app/log/command/config.proto package command @@ -63,7 +63,7 @@ type LoggerServiceServer interface { type UnimplementedLoggerServiceServer struct{} func (UnimplementedLoggerServiceServer) RestartLogger(context.Context, *RestartLoggerRequest) (*RestartLoggerResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method RestartLogger not implemented") + return nil, status.Error(codes.Unimplemented, "method RestartLogger not implemented") } func (UnimplementedLoggerServiceServer) mustEmbedUnimplementedLoggerServiceServer() {} func (UnimplementedLoggerServiceServer) testEmbeddedByValue() {} @@ -76,7 +76,7 @@ type UnsafeLoggerServiceServer interface { } func RegisterLoggerServiceServer(s grpc.ServiceRegistrar, srv LoggerServiceServer) { - // If the following call pancis, it indicates UnimplementedLoggerServiceServer was + // If the following call panics, it indicates UnimplementedLoggerServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. diff --git a/app/observatory/burst/ping.go b/app/observatory/burst/ping.go index cdc72bfb..1318d309 100644 --- a/app/observatory/burst/ping.go +++ b/app/observatory/burst/ping.go @@ -7,6 +7,7 @@ import ( "time" "github.com/amnezia-vpn/amnezia-xray-core/common/net" + "github.com/amnezia-vpn/amnezia-xray-core/common/utils" "github.com/amnezia-vpn/amnezia-xray-core/features/routing" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/tagged" ) @@ -61,6 +62,7 @@ func (s *pingClient) MeasureDelay(httpMethod string) (time.Duration, error) { if err != nil { return rttFailed, err } + req.Header.Set("User-Agent", utils.ChromeUA) start := time.Now() resp, err := s.httpClient.Do(req) diff --git a/app/observatory/command/command_grpc.pb.go b/app/observatory/command/command_grpc.pb.go index 3b1e3be7..a6bd6359 100644 --- a/app/observatory/command/command_grpc.pb.go +++ b/app/observatory/command/command_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v5.28.2 +// - protoc-gen-go-grpc v1.6.1 +// - protoc v3.21.12 // source: app/observatory/command/command.proto package command @@ -63,7 +63,7 @@ type ObservatoryServiceServer interface { type UnimplementedObservatoryServiceServer struct{} func (UnimplementedObservatoryServiceServer) GetOutboundStatus(context.Context, *GetOutboundStatusRequest) (*GetOutboundStatusResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetOutboundStatus not implemented") + return nil, status.Error(codes.Unimplemented, "method GetOutboundStatus not implemented") } func (UnimplementedObservatoryServiceServer) mustEmbedUnimplementedObservatoryServiceServer() {} func (UnimplementedObservatoryServiceServer) testEmbeddedByValue() {} @@ -76,7 +76,7 @@ type UnsafeObservatoryServiceServer interface { } func RegisterObservatoryServiceServer(s grpc.ServiceRegistrar, srv ObservatoryServiceServer) { - // If the following call pancis, it indicates UnimplementedObservatoryServiceServer was + // If the following call panics, it indicates UnimplementedObservatoryServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. diff --git a/app/observatory/observer.go b/app/observatory/observer.go index 62202280..5ce1ec6e 100644 --- a/app/observatory/observer.go +++ b/app/observatory/observer.go @@ -15,6 +15,7 @@ import ( "github.com/amnezia-vpn/amnezia-xray-core/common/session" "github.com/amnezia-vpn/amnezia-xray-core/common/signal/done" "github.com/amnezia-vpn/amnezia-xray-core/common/task" + "github.com/amnezia-vpn/amnezia-xray-core/common/utils" "github.com/amnezia-vpn/amnezia-xray-core/core" "github.com/amnezia-vpn/amnezia-xray-core/features/extension" "github.com/amnezia-vpn/amnezia-xray-core/features/outbound" @@ -162,7 +163,9 @@ func (o *Observer) probe(outbound string) ProbeResult { if o.config.ProbeUrl != "" { probeURL = o.config.ProbeUrl } - response, err := httpClient.Get(probeURL) + req, _ := http.NewRequest(http.MethodGet, probeURL, nil) + req.Header.Set("User-Agent", utils.ChromeUA) + response, err := httpClient.Do(req) if err != nil { return errors.New("outbound failed to relay connection").Base(err) } diff --git a/app/proxyman/command/command_grpc.pb.go b/app/proxyman/command/command_grpc.pb.go index 8bc48f0a..819c68b1 100644 --- a/app/proxyman/command/command_grpc.pb.go +++ b/app/proxyman/command/command_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v5.28.2 +// - protoc-gen-go-grpc v1.6.1 +// - protoc v3.21.12 // source: app/proxyman/command/command.proto package command @@ -180,34 +180,34 @@ type HandlerServiceServer interface { type UnimplementedHandlerServiceServer struct{} func (UnimplementedHandlerServiceServer) AddInbound(context.Context, *AddInboundRequest) (*AddInboundResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method AddInbound not implemented") + return nil, status.Error(codes.Unimplemented, "method AddInbound not implemented") } func (UnimplementedHandlerServiceServer) RemoveInbound(context.Context, *RemoveInboundRequest) (*RemoveInboundResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method RemoveInbound not implemented") + return nil, status.Error(codes.Unimplemented, "method RemoveInbound not implemented") } func (UnimplementedHandlerServiceServer) AlterInbound(context.Context, *AlterInboundRequest) (*AlterInboundResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method AlterInbound not implemented") + return nil, status.Error(codes.Unimplemented, "method AlterInbound not implemented") } func (UnimplementedHandlerServiceServer) ListInbounds(context.Context, *ListInboundsRequest) (*ListInboundsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListInbounds not implemented") + return nil, status.Error(codes.Unimplemented, "method ListInbounds not implemented") } func (UnimplementedHandlerServiceServer) GetInboundUsers(context.Context, *GetInboundUserRequest) (*GetInboundUserResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetInboundUsers not implemented") + return nil, status.Error(codes.Unimplemented, "method GetInboundUsers not implemented") } func (UnimplementedHandlerServiceServer) GetInboundUsersCount(context.Context, *GetInboundUserRequest) (*GetInboundUsersCountResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetInboundUsersCount not implemented") + return nil, status.Error(codes.Unimplemented, "method GetInboundUsersCount not implemented") } func (UnimplementedHandlerServiceServer) AddOutbound(context.Context, *AddOutboundRequest) (*AddOutboundResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method AddOutbound not implemented") + return nil, status.Error(codes.Unimplemented, "method AddOutbound not implemented") } func (UnimplementedHandlerServiceServer) RemoveOutbound(context.Context, *RemoveOutboundRequest) (*RemoveOutboundResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method RemoveOutbound not implemented") + return nil, status.Error(codes.Unimplemented, "method RemoveOutbound not implemented") } func (UnimplementedHandlerServiceServer) AlterOutbound(context.Context, *AlterOutboundRequest) (*AlterOutboundResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method AlterOutbound not implemented") + return nil, status.Error(codes.Unimplemented, "method AlterOutbound not implemented") } func (UnimplementedHandlerServiceServer) ListOutbounds(context.Context, *ListOutboundsRequest) (*ListOutboundsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListOutbounds not implemented") + return nil, status.Error(codes.Unimplemented, "method ListOutbounds not implemented") } func (UnimplementedHandlerServiceServer) mustEmbedUnimplementedHandlerServiceServer() {} func (UnimplementedHandlerServiceServer) testEmbeddedByValue() {} @@ -220,7 +220,7 @@ type UnsafeHandlerServiceServer interface { } func RegisterHandlerServiceServer(s grpc.ServiceRegistrar, srv HandlerServiceServer) { - // If the following call pancis, it indicates UnimplementedHandlerServiceServer was + // If the following call panics, it indicates UnimplementedHandlerServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. diff --git a/app/router/command/command_grpc.pb.go b/app/router/command/command_grpc.pb.go index f134d621..0688ae74 100644 --- a/app/router/command/command_grpc.pb.go +++ b/app/router/command/command_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.6.0 -// - protoc v6.33.2 +// - protoc-gen-go-grpc v1.6.1 +// - protoc v3.21.12 // source: app/router/command/command.proto package command diff --git a/app/router/condition.go b/app/router/condition.go index ecab1591..05b66698 100644 --- a/app/router/condition.go +++ b/app/router/condition.go @@ -2,6 +2,7 @@ package router import ( "context" + "io" "os" "path/filepath" "regexp" @@ -52,7 +53,34 @@ var matcherTypeMap = map[Domain_Type]strmatcher.Type{ } type DomainMatcher struct { - matchers strmatcher.IndexMatcher + Matchers strmatcher.IndexMatcher +} + +func SerializeDomainMatcher(domains []*Domain, w io.Writer) error { + + g := strmatcher.NewMphMatcherGroup() + for _, d := range domains { + matcherType, f := matcherTypeMap[d.Type] + if !f { + continue + } + + _, err := g.AddPattern(d.Value, matcherType) + if err != nil { + return err + } + } + g.Build() + // serialize + return g.Serialize(w) +} + +func NewDomainMatcherFromBuffer(data []byte) (*strmatcher.MphMatcherGroup, error) { + matcher, err := strmatcher.NewMphMatcherGroupFromBuffer(data) + if err != nil { + return nil, err + } + return matcher, nil } func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) { @@ -72,12 +100,12 @@ func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) { } g.Build() return &DomainMatcher{ - matchers: g, + Matchers: g, }, nil } func (m *DomainMatcher) ApplyDomain(domain string) bool { - return len(m.matchers.Match(strings.ToLower(domain))) > 0 + return len(m.Matchers.Match(strings.ToLower(domain))) > 0 } // Apply implements Condition. diff --git a/app/router/condition_serialize_test.go b/app/router/condition_serialize_test.go new file mode 100644 index 00000000..cd446254 --- /dev/null +++ b/app/router/condition_serialize_test.go @@ -0,0 +1,167 @@ +package router_test + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/amnezia-vpn/amnezia-xray-core/app/router" + "github.com/amnezia-vpn/amnezia-xray-core/common/platform/filesystem" + "github.com/stretchr/testify/require" +) + +func TestDomainMatcherSerialization(t *testing.T) { + + domains := []*router.Domain{ + {Type: router.Domain_Domain, Value: "google.com"}, + {Type: router.Domain_Domain, Value: "v2ray.com"}, + {Type: router.Domain_Full, Value: "full.example.com"}, + } + + var buf bytes.Buffer + if err := router.SerializeDomainMatcher(domains, &buf); err != nil { + t.Fatalf("Serialize failed: %v", err) + } + + matcher, err := router.NewDomainMatcherFromBuffer(buf.Bytes()) + if err != nil { + t.Fatalf("Deserialize failed: %v", err) + } + + dMatcher := &router.DomainMatcher{ + Matchers: matcher, + } + testCases := []struct { + Input string + Match bool + }{ + {"google.com", true}, + {"maps.google.com", true}, + {"v2ray.com", true}, + {"full.example.com", true}, + + {"example.com", false}, + } + + for _, tc := range testCases { + if res := dMatcher.ApplyDomain(tc.Input); res != tc.Match { + t.Errorf("Match(%s) = %v, want %v", tc.Input, res, tc.Match) + } + } +} + +func TestGeoSiteSerialization(t *testing.T) { + sites := []*router.GeoSite{ + { + CountryCode: "CN", + Domain: []*router.Domain{ + {Type: router.Domain_Domain, Value: "baidu.cn"}, + {Type: router.Domain_Domain, Value: "qq.com"}, + }, + }, + { + CountryCode: "US", + Domain: []*router.Domain{ + {Type: router.Domain_Domain, Value: "google.com"}, + {Type: router.Domain_Domain, Value: "facebook.com"}, + }, + }, + } + + var buf bytes.Buffer + if err := router.SerializeGeoSiteList(sites, nil, nil, &buf); err != nil { + t.Fatalf("SerializeGeoSiteList failed: %v", err) + } + + tmp := t.TempDir() + path := filepath.Join(tmp, "matcher.cache") + + f, err := os.Create(path) + require.NoError(t, err) + _, err = f.Write(buf.Bytes()) + require.NoError(t, err) + f.Close() + + f, err = os.Open(path) + require.NoError(t, err) + defer f.Close() + + require.NoError(t, err) + data, _ := filesystem.ReadFile(path) + + // cn + gp, err := router.LoadGeoSiteMatcher(bytes.NewReader(data), "CN") + if err != nil { + t.Fatalf("LoadGeoSiteMatcher(CN) failed: %v", err) + } + + cnMatcher := &router.DomainMatcher{ + Matchers: gp, + } + + if !cnMatcher.ApplyDomain("baidu.cn") { + t.Error("CN matcher should match baidu.cn") + } + if cnMatcher.ApplyDomain("google.com") { + t.Error("CN matcher should NOT match google.com") + } + + // us + gp, err = router.LoadGeoSiteMatcher(bytes.NewReader(data), "US") + if err != nil { + t.Fatalf("LoadGeoSiteMatcher(US) failed: %v", err) + } + + usMatcher := &router.DomainMatcher{ + Matchers: gp, + } + if !usMatcher.ApplyDomain("google.com") { + t.Error("US matcher should match google.com") + } + if usMatcher.ApplyDomain("baidu.cn") { + t.Error("US matcher should NOT match baidu.cn") + } + + // unknown + _, err = router.LoadGeoSiteMatcher(bytes.NewReader(data), "unknown") + if err == nil { + t.Error("LoadGeoSiteMatcher(unknown) should fail") + } +} +func TestGeoSiteSerializationWithDeps(t *testing.T) { + sites := []*router.GeoSite{ + { + CountryCode: "geosite:cn", + Domain: []*router.Domain{ + {Type: router.Domain_Domain, Value: "baidu.cn"}, + }, + }, + { + CountryCode: "geosite:google@cn", + Domain: []*router.Domain{ + {Type: router.Domain_Domain, Value: "google.cn"}, + }, + }, + { + CountryCode: "rule-1", + Domain: []*router.Domain{ + {Type: router.Domain_Domain, Value: "google.com"}, + }, + }, + } + deps := map[string][]string{ + "rule-1": {"geosite:cn", "geosite:google@cn"}, + } + + var buf bytes.Buffer + err := router.SerializeGeoSiteList(sites, deps, nil, &buf) + require.NoError(t, err) + + matcher, err := router.LoadGeoSiteMatcher(bytes.NewReader(buf.Bytes()), "rule-1") + require.NoError(t, err) + + require.True(t, matcher.Match("google.com") != nil) + require.True(t, matcher.Match("baidu.cn") != nil) + require.True(t, matcher.Match("google.cn") != nil) +} diff --git a/app/router/config.go b/app/router/config.go index 2c3cb93a..92875cbc 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -7,6 +7,8 @@ import ( "strings" "github.com/amnezia-vpn/amnezia-xray-core/common/errors" + "github.com/amnezia-vpn/amnezia-xray-core/common/platform" + "github.com/amnezia-vpn/amnezia-xray-core/common/platform/filesystem" "github.com/amnezia-vpn/amnezia-xray-core/features/outbound" "github.com/amnezia-vpn/amnezia-xray-core/features/routing" ) @@ -105,11 +107,25 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { } if len(rr.Domain) > 0 { - matcher, err := NewMphMatcherGroup(rr.Domain) - if err != nil { - return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err) + var matcher *DomainMatcher + var err error + // Check if domain matcher cache is provided via environment + domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) + + if domainMatcherPath != "" { + matcher, err = GetDomainMatcherWithRuleTag(domainMatcherPath, rr.RuleTag) + if err != nil { + return nil, errors.New("failed to build domain condition from cached MphDomainMatcher").Base(err) + } + errors.LogDebug(context.Background(), "MphDomainMatcher loaded from cache for ", rr.RuleTag, " rule tag)") + + } else { + matcher, err = NewMphMatcherGroup(rr.Domain) + if err != nil { + return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err) + } + errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)") } - errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)") conds.Add(matcher) rr.Domain = nil runtime.GC() @@ -172,3 +188,20 @@ func (br *BalancingRule) Build(ohm outbound.Manager, dispatcher routing.Dispatch return nil, errors.New("unrecognized balancer type") } } + +func GetDomainMatcherWithRuleTag(domainMatcherPath string, ruleTag string) (*DomainMatcher, error) { + f, err := filesystem.NewFileReader(domainMatcherPath) + if err != nil { + return nil, errors.New("failed to load file: ", domainMatcherPath).Base(err) + } + defer f.Close() + + g, err := LoadGeoSiteMatcher(f, ruleTag) + if err != nil { + return nil, errors.New("failed to load file:", domainMatcherPath).Base(err) + } + return &DomainMatcher{ + Matchers: g, + }, nil + +} diff --git a/app/router/geosite_compact.go b/app/router/geosite_compact.go new file mode 100644 index 00000000..26907eee --- /dev/null +++ b/app/router/geosite_compact.go @@ -0,0 +1,100 @@ +package router + +import ( + "encoding/gob" + "errors" + "io" + "runtime" + + "github.com/amnezia-vpn/amnezia-xray-core/common/strmatcher" +) + +type geoSiteListGob struct { + Sites map[string][]byte + Deps map[string][]string + Hosts map[string][]string +} + +func SerializeGeoSiteList(sites []*GeoSite, deps map[string][]string, hosts map[string][]string, w io.Writer) error { + data := geoSiteListGob{ + Sites: make(map[string][]byte), + Deps: deps, + Hosts: hosts, + } + + for _, site := range sites { + if site == nil { + continue + } + var buf bytesWriter + if err := SerializeDomainMatcher(site.Domain, &buf); err != nil { + return err + } + data.Sites[site.CountryCode] = buf.Bytes() + } + + return gob.NewEncoder(w).Encode(data) +} + +type bytesWriter struct { + data []byte +} + +func (w *bytesWriter) Write(p []byte) (n int, err error) { + w.data = append(w.data, p...) + return len(p), nil +} + +func (w *bytesWriter) Bytes() []byte { + return w.data +} + +func LoadGeoSiteMatcher(r io.Reader, countryCode string) (strmatcher.IndexMatcher, error) { + var data geoSiteListGob + if err := gob.NewDecoder(r).Decode(&data); err != nil { + return nil, err + } + + return loadWithDeps(&data, countryCode, make(map[string]bool)) +} + +func loadWithDeps(data *geoSiteListGob, code string, visited map[string]bool) (strmatcher.IndexMatcher, error) { + if visited[code] { + return nil, errors.New("cyclic dependency") + } + visited[code] = true + + var matchers []strmatcher.IndexMatcher + + if siteData, ok := data.Sites[code]; ok { + m, err := NewDomainMatcherFromBuffer(siteData) + if err == nil { + matchers = append(matchers, m) + } + } + + if deps, ok := data.Deps[code]; ok { + for _, dep := range deps { + m, err := loadWithDeps(data, dep, visited) + if err == nil { + matchers = append(matchers, m) + } + } + } + + if len(matchers) == 0 { + return nil, errors.New("matcher not found for: " + code) + } + if len(matchers) == 1 { + return matchers[0], nil + } + runtime.GC() + return &strmatcher.IndexMatcherGroup{Matchers: matchers}, nil +} +func LoadGeoSiteHosts(r io.Reader) (map[string][]string, error) { + var data geoSiteListGob + if err := gob.NewDecoder(r).Decode(&data); err != nil { + return nil, err + } + return data.Hosts, nil +} diff --git a/app/stats/command/command_grpc.pb.go b/app/stats/command/command_grpc.pb.go index 9864e7f7..4a5f8889 100644 --- a/app/stats/command/command_grpc.pb.go +++ b/app/stats/command/command_grpc.pb.go @@ -1,8 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v6.32.0 -// source: command.proto +// - protoc-gen-go-grpc v1.6.1 +// - protoc v3.21.12 +// source: app/stats/command/command.proto package command @@ -128,22 +128,22 @@ type StatsServiceServer interface { type UnimplementedStatsServiceServer struct{} func (UnimplementedStatsServiceServer) GetStats(context.Context, *GetStatsRequest) (*GetStatsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetStats not implemented") + return nil, status.Error(codes.Unimplemented, "method GetStats not implemented") } func (UnimplementedStatsServiceServer) GetStatsOnline(context.Context, *GetStatsRequest) (*GetStatsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetStatsOnline not implemented") + return nil, status.Error(codes.Unimplemented, "method GetStatsOnline not implemented") } func (UnimplementedStatsServiceServer) QueryStats(context.Context, *QueryStatsRequest) (*QueryStatsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method QueryStats not implemented") + return nil, status.Error(codes.Unimplemented, "method QueryStats not implemented") } func (UnimplementedStatsServiceServer) GetSysStats(context.Context, *SysStatsRequest) (*SysStatsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetSysStats not implemented") + return nil, status.Error(codes.Unimplemented, "method GetSysStats not implemented") } func (UnimplementedStatsServiceServer) GetStatsOnlineIpList(context.Context, *GetStatsRequest) (*GetStatsOnlineIpListResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetStatsOnlineIpList not implemented") + return nil, status.Error(codes.Unimplemented, "method GetStatsOnlineIpList not implemented") } func (UnimplementedStatsServiceServer) GetAllOnlineUsers(context.Context, *GetAllOnlineUsersRequest) (*GetAllOnlineUsersResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetAllOnlineUsers not implemented") + return nil, status.Error(codes.Unimplemented, "method GetAllOnlineUsers not implemented") } func (UnimplementedStatsServiceServer) mustEmbedUnimplementedStatsServiceServer() {} func (UnimplementedStatsServiceServer) testEmbeddedByValue() {} @@ -156,7 +156,7 @@ type UnsafeStatsServiceServer interface { } func RegisterStatsServiceServer(s grpc.ServiceRegistrar, srv StatsServiceServer) { - // If the following call pancis, it indicates UnimplementedStatsServiceServer was + // If the following call panics, it indicates UnimplementedStatsServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. @@ -307,5 +307,5 @@ var StatsService_ServiceDesc = grpc.ServiceDesc{ }, }, Streams: []grpc.StreamDesc{}, - Metadata: "command.proto", + Metadata: "app/stats/command/command.proto", } diff --git a/common/platform/platform.go b/common/platform/platform.go index 40750286..f4655294 100644 --- a/common/platform/platform.go +++ b/common/platform/platform.go @@ -24,6 +24,8 @@ const ( XUDPBaseKey = "xray.xudp.basekey" TunFdKey = "xray.tun.fd" + + MphCachePath = "xray.mph.cache" ) type EnvFlag struct { diff --git a/common/strmatcher/ac_automaton_matcher.go b/common/strmatcher/ac_automaton_matcher.go index 24be9dac..7844333d 100644 --- a/common/strmatcher/ac_automaton_matcher.go +++ b/common/strmatcher/ac_automaton_matcher.go @@ -7,8 +7,8 @@ import ( const validCharCount = 53 type MatchType struct { - matchType Type - exist bool + Type Type + Exist bool } const ( @@ -17,23 +17,23 @@ const ( ) type Edge struct { - edgeType bool - nextNode int + Type bool + NextNode int } type ACAutomaton struct { - trie [][validCharCount]Edge - fail []int - exists []MatchType - count int + Trie [][validCharCount]Edge + Fail []int + Exists []MatchType + Count int } func newNode() [validCharCount]Edge { var s [validCharCount]Edge for i := range s { s[i] = Edge{ - edgeType: FailEdge, - nextNode: 0, + Type: FailEdge, + NextNode: 0, } } return s @@ -123,11 +123,11 @@ var char2Index = []int{ func NewACAutomaton() *ACAutomaton { ac := new(ACAutomaton) - ac.trie = append(ac.trie, newNode()) - ac.fail = append(ac.fail, 0) - ac.exists = append(ac.exists, MatchType{ - matchType: Full, - exist: false, + ac.Trie = append(ac.Trie, newNode()) + ac.Fail = append(ac.Fail, 0) + ac.Exists = append(ac.Exists, MatchType{ + Type: Full, + Exist: false, }) return ac } @@ -136,53 +136,53 @@ func (ac *ACAutomaton) Add(domain string, t Type) { node := 0 for i := len(domain) - 1; i >= 0; i-- { idx := char2Index[domain[i]] - if ac.trie[node][idx].nextNode == 0 { - ac.count++ - if len(ac.trie) < ac.count+1 { - ac.trie = append(ac.trie, newNode()) - ac.fail = append(ac.fail, 0) - ac.exists = append(ac.exists, MatchType{ - matchType: Full, - exist: false, + if ac.Trie[node][idx].NextNode == 0 { + ac.Count++ + if len(ac.Trie) < ac.Count+1 { + ac.Trie = append(ac.Trie, newNode()) + ac.Fail = append(ac.Fail, 0) + ac.Exists = append(ac.Exists, MatchType{ + Type: Full, + Exist: false, }) } - ac.trie[node][idx] = Edge{ - edgeType: TrieEdge, - nextNode: ac.count, + ac.Trie[node][idx] = Edge{ + Type: TrieEdge, + NextNode: ac.Count, } } - node = ac.trie[node][idx].nextNode + node = ac.Trie[node][idx].NextNode } - ac.exists[node] = MatchType{ - matchType: t, - exist: true, + ac.Exists[node] = MatchType{ + Type: t, + Exist: true, } switch t { case Domain: - ac.exists[node] = MatchType{ - matchType: Full, - exist: true, + ac.Exists[node] = MatchType{ + Type: Full, + Exist: true, } idx := char2Index['.'] - if ac.trie[node][idx].nextNode == 0 { - ac.count++ - if len(ac.trie) < ac.count+1 { - ac.trie = append(ac.trie, newNode()) - ac.fail = append(ac.fail, 0) - ac.exists = append(ac.exists, MatchType{ - matchType: Full, - exist: false, + if ac.Trie[node][idx].NextNode == 0 { + ac.Count++ + if len(ac.Trie) < ac.Count+1 { + ac.Trie = append(ac.Trie, newNode()) + ac.Fail = append(ac.Fail, 0) + ac.Exists = append(ac.Exists, MatchType{ + Type: Full, + Exist: false, }) } - ac.trie[node][idx] = Edge{ - edgeType: TrieEdge, - nextNode: ac.count, + ac.Trie[node][idx] = Edge{ + Type: TrieEdge, + NextNode: ac.Count, } } - node = ac.trie[node][idx].nextNode - ac.exists[node] = MatchType{ - matchType: t, - exist: true, + node = ac.Trie[node][idx].NextNode + ac.Exists[node] = MatchType{ + Type: t, + Exist: true, } default: break @@ -192,8 +192,8 @@ func (ac *ACAutomaton) Add(domain string, t Type) { func (ac *ACAutomaton) Build() { queue := list.New() for i := 0; i < validCharCount; i++ { - if ac.trie[0][i].nextNode != 0 { - queue.PushBack(ac.trie[0][i]) + if ac.Trie[0][i].NextNode != 0 { + queue.PushBack(ac.Trie[0][i]) } } for { @@ -201,16 +201,16 @@ func (ac *ACAutomaton) Build() { if front == nil { break } else { - node := front.Value.(Edge).nextNode + node := front.Value.(Edge).NextNode queue.Remove(front) for i := 0; i < validCharCount; i++ { - if ac.trie[node][i].nextNode != 0 { - ac.fail[ac.trie[node][i].nextNode] = ac.trie[ac.fail[node]][i].nextNode - queue.PushBack(ac.trie[node][i]) + if ac.Trie[node][i].NextNode != 0 { + ac.Fail[ac.Trie[node][i].NextNode] = ac.Trie[ac.Fail[node]][i].NextNode + queue.PushBack(ac.Trie[node][i]) } else { - ac.trie[node][i] = Edge{ - edgeType: FailEdge, - nextNode: ac.trie[ac.fail[node]][i].nextNode, + ac.Trie[node][i] = Edge{ + Type: FailEdge, + NextNode: ac.Trie[ac.Fail[node]][i].NextNode, } } } @@ -230,9 +230,9 @@ func (ac *ACAutomaton) Match(s string) bool { return false } idx := char2Index[chr] - fullMatch = fullMatch && ac.trie[node][idx].edgeType - node = ac.trie[node][idx].nextNode - switch ac.exists[node].matchType { + fullMatch = fullMatch && ac.Trie[node][idx].Type + node = ac.Trie[node][idx].NextNode + switch ac.Exists[node].Type { case Substr: return true case Domain: @@ -243,5 +243,5 @@ func (ac *ACAutomaton) Match(s string) bool { break } } - return fullMatch && ac.exists[node].exist + return fullMatch && ac.Exists[node].Exist } diff --git a/common/strmatcher/matchers.go b/common/strmatcher/matchers.go index b5ab09c4..915927db 100644 --- a/common/strmatcher/matchers.go +++ b/common/strmatcher/matchers.go @@ -39,14 +39,18 @@ func (m domainMatcher) String() string { return "domain:" + string(m) } -type regexMatcher struct { - pattern *regexp.Regexp +type RegexMatcher struct { + Pattern string + reg *regexp.Regexp } -func (m *regexMatcher) Match(s string) bool { - return m.pattern.MatchString(s) +func (m *RegexMatcher) Match(s string) bool { + if m.reg == nil { + m.reg = regexp.MustCompile(m.Pattern) + } + return m.reg.MatchString(s) } -func (m *regexMatcher) String() string { - return "regexp:" + m.pattern.String() +func (m *RegexMatcher) String() string { + return "regexp:" + m.Pattern } diff --git a/common/strmatcher/mph_matcher.go b/common/strmatcher/mph_matcher.go index 3c10cb49..ff3dea65 100644 --- a/common/strmatcher/mph_matcher.go +++ b/common/strmatcher/mph_matcher.go @@ -25,40 +25,40 @@ func RollingHash(s string) uint32 { // 2. `substr` patterns are matched by ac automaton; // 3. `regex` patterns are matched with the regex library. type MphMatcherGroup struct { - ac *ACAutomaton - otherMatchers []matcherEntry - rules []string - level0 []uint32 - level0Mask int - level1 []uint32 - level1Mask int - count uint32 - ruleMap *map[string]uint32 + Ac *ACAutomaton + OtherMatchers []MatcherEntry + Rules []string + Level0 []uint32 + Level0Mask int + Level1 []uint32 + Level1Mask int + Count uint32 + RuleMap *map[string]uint32 } func (g *MphMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) { h := RollingHash(pattern) switch t { case Domain: - (*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.') + (*g.RuleMap)["."+pattern] = h*PrimeRK + uint32('.') fallthrough case Full: - (*g.ruleMap)[pattern] = h + (*g.RuleMap)[pattern] = h default: } } func NewMphMatcherGroup() *MphMatcherGroup { return &MphMatcherGroup{ - ac: nil, - otherMatchers: nil, - rules: nil, - level0: nil, - level0Mask: 0, - level1: nil, - level1Mask: 0, - count: 1, - ruleMap: &map[string]uint32{}, + Ac: nil, + OtherMatchers: nil, + Rules: nil, + Level0: nil, + Level0Mask: 0, + Level1: nil, + Level1Mask: 0, + Count: 1, + RuleMap: &map[string]uint32{}, } } @@ -66,10 +66,10 @@ func NewMphMatcherGroup() *MphMatcherGroup { func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) { switch t { case Substr: - if g.ac == nil { - g.ac = NewACAutomaton() + if g.Ac == nil { + g.Ac = NewACAutomaton() } - g.ac.Add(pattern, t) + g.Ac.Add(pattern, t) case Full, Domain: pattern = strings.ToLower(pattern) g.AddFullOrDomainPattern(pattern, t) @@ -78,39 +78,39 @@ func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) { if err != nil { return 0, err } - g.otherMatchers = append(g.otherMatchers, matcherEntry{ - m: ®exMatcher{pattern: r}, - id: g.count, + g.OtherMatchers = append(g.OtherMatchers, MatcherEntry{ + M: &RegexMatcher{Pattern: pattern, reg: r}, + Id: g.Count, }) default: panic("Unknown type") } - return g.count, nil + return g.Count, nil } // Build builds a minimal perfect hash table and ac automaton from insert rules func (g *MphMatcherGroup) Build() { - if g.ac != nil { - g.ac.Build() + if g.Ac != nil { + g.Ac.Build() } - keyLen := len(*g.ruleMap) + keyLen := len(*g.RuleMap) if keyLen == 0 { keyLen = 1 - (*g.ruleMap)["empty___"] = RollingHash("empty___") + (*g.RuleMap)["empty___"] = RollingHash("empty___") } - g.level0 = make([]uint32, nextPow2(keyLen/4)) - g.level0Mask = len(g.level0) - 1 - g.level1 = make([]uint32, nextPow2(keyLen)) - g.level1Mask = len(g.level1) - 1 - sparseBuckets := make([][]int, len(g.level0)) + g.Level0 = make([]uint32, nextPow2(keyLen/4)) + g.Level0Mask = len(g.Level0) - 1 + g.Level1 = make([]uint32, nextPow2(keyLen)) + g.Level1Mask = len(g.Level1) - 1 + sparseBuckets := make([][]int, len(g.Level0)) var ruleIdx int - for rule, hash := range *g.ruleMap { - n := int(hash) & g.level0Mask - g.rules = append(g.rules, rule) + for rule, hash := range *g.RuleMap { + n := int(hash) & g.Level0Mask + g.Rules = append(g.Rules, rule) sparseBuckets[n] = append(sparseBuckets[n], ruleIdx) ruleIdx++ } - g.ruleMap = nil + g.RuleMap = nil var buckets []indexBucket for n, vals := range sparseBuckets { if len(vals) > 0 { @@ -119,7 +119,7 @@ func (g *MphMatcherGroup) Build() { } sort.Sort(bySize(buckets)) - occ := make([]bool, len(g.level1)) + occ := make([]bool, len(g.Level1)) var tmpOcc []int for _, bucket := range buckets { seed := uint32(0) @@ -127,7 +127,7 @@ func (g *MphMatcherGroup) Build() { findSeed := true tmpOcc = tmpOcc[:0] for _, i := range bucket.vals { - n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask + n := int(strhashFallback(unsafe.Pointer(&g.Rules[i]), uintptr(seed))) & g.Level1Mask if occ[n] { for _, n := range tmpOcc { occ[n] = false @@ -138,10 +138,10 @@ func (g *MphMatcherGroup) Build() { } occ[n] = true tmpOcc = append(tmpOcc, n) - g.level1[n] = uint32(i) + g.Level1[n] = uint32(i) } if findSeed { - g.level0[bucket.n] = seed + g.Level0[bucket.n] = seed break } } @@ -159,11 +159,11 @@ func nextPow2(v int) int { // Lookup searches for s in t and returns its index and whether it was found. func (g *MphMatcherGroup) Lookup(h uint32, s string) bool { - i0 := int(h) & g.level0Mask - seed := g.level0[i0] - i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask - n := g.level1[i1] - return s == g.rules[int(n)] + i0 := int(h) & g.Level0Mask + seed := g.Level0[i0] + i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.Level1Mask + n := g.Level1[i1] + return s == g.Rules[int(n)] } // Match implements IndexMatcher.Match. @@ -183,13 +183,13 @@ func (g *MphMatcherGroup) Match(pattern string) []uint32 { result = append(result, 1) return result } - if g.ac != nil && g.ac.Match(pattern) { + if g.Ac != nil && g.Ac.Match(pattern) { result = append(result, 1) return result } - for _, e := range g.otherMatchers { - if e.m.Match(pattern) { - result = append(result, e.id) + for _, e := range g.OtherMatchers { + if e.M.Match(pattern) { + result = append(result, e.Id) return result } } @@ -302,3 +302,7 @@ func readUnaligned64(p unsafe.Pointer) uint64 { q := (*[8]byte)(p) return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56 } + +func (g *MphMatcherGroup) Size() uint32 { + return g.Count +} diff --git a/common/strmatcher/mph_matcher_compact.go b/common/strmatcher/mph_matcher_compact.go new file mode 100644 index 00000000..a40b9f56 --- /dev/null +++ b/common/strmatcher/mph_matcher_compact.go @@ -0,0 +1,47 @@ +package strmatcher + +import ( + "bytes" + "encoding/gob" + "io" +) + +func init() { + gob.Register(&RegexMatcher{}) + gob.Register(fullMatcher("")) + gob.Register(substrMatcher("")) + gob.Register(domainMatcher("")) +} + +func (g *MphMatcherGroup) Serialize(w io.Writer) error { + data := MphMatcherGroup{ + Ac: g.Ac, + OtherMatchers: g.OtherMatchers, + Rules: g.Rules, + Level0: g.Level0, + Level0Mask: g.Level0Mask, + Level1: g.Level1, + Level1Mask: g.Level1Mask, + Count: g.Count, + } + return gob.NewEncoder(w).Encode(data) +} + +func NewMphMatcherGroupFromBuffer(data []byte) (*MphMatcherGroup, error) { + var gData MphMatcherGroup + if err := gob.NewDecoder(bytes.NewReader(data)).Decode(&gData); err != nil { + return nil, err + } + + g := NewMphMatcherGroup() + g.Ac = gData.Ac + g.OtherMatchers = gData.OtherMatchers + g.Rules = gData.Rules + g.Level0 = gData.Level0 + g.Level0Mask = gData.Level0Mask + g.Level1 = gData.Level1 + g.Level1Mask = gData.Level1Mask + g.Count = gData.Count + + return g, nil +} diff --git a/common/strmatcher/strmatcher.go b/common/strmatcher/strmatcher.go index 4035acc3..89e7dae6 100644 --- a/common/strmatcher/strmatcher.go +++ b/common/strmatcher/strmatcher.go @@ -41,8 +41,9 @@ func (t Type) New(pattern string) (Matcher, error) { if err != nil { return nil, err } - return ®exMatcher{ - pattern: r, + return &RegexMatcher{ + Pattern: pattern, + reg: r, }, nil default: return nil, errors.New("unk type") @@ -53,11 +54,13 @@ func (t Type) New(pattern string) (Matcher, error) { type IndexMatcher interface { // Match returns the index of a matcher that matches the input. It returns empty array if no such matcher exists. Match(input string) []uint32 + // Size returns the number of matchers in the group. + Size() uint32 } -type matcherEntry struct { - m Matcher - id uint32 +type MatcherEntry struct { + M Matcher + Id uint32 } // MatcherGroup is an implementation of IndexMatcher. @@ -66,7 +69,7 @@ type MatcherGroup struct { count uint32 fullMatcher FullMatcherGroup domainMatcher DomainMatcherGroup - otherMatchers []matcherEntry + otherMatchers []MatcherEntry } // Add adds a new Matcher into the MatcherGroup, and returns its index. The index will never be 0. @@ -80,9 +83,9 @@ func (g *MatcherGroup) Add(m Matcher) uint32 { case domainMatcher: g.domainMatcher.addMatcher(tm, c) default: - g.otherMatchers = append(g.otherMatchers, matcherEntry{ - m: m, - id: c, + g.otherMatchers = append(g.otherMatchers, MatcherEntry{ + M: m, + Id: c, }) } @@ -95,8 +98,8 @@ func (g *MatcherGroup) Match(pattern string) []uint32 { result = append(result, g.fullMatcher.Match(pattern)...) result = append(result, g.domainMatcher.Match(pattern)...) for _, e := range g.otherMatchers { - if e.m.Match(pattern) { - result = append(result, e.id) + if e.M.Match(pattern) { + result = append(result, e.Id) } } return result @@ -106,3 +109,33 @@ func (g *MatcherGroup) Match(pattern string) []uint32 { func (g *MatcherGroup) Size() uint32 { return g.count } + +type IndexMatcherGroup struct { + Matchers []IndexMatcher +} + +func (g *IndexMatcherGroup) Match(input string) []uint32 { + var offset uint32 + for _, m := range g.Matchers { + if res := m.Match(input); len(res) > 0 { + if offset == 0 { + return res + } + shifted := make([]uint32, len(res)) + for i, id := range res { + shifted[i] = id + offset + } + return shifted + } + offset += m.Size() + } + return nil +} + +func (g *IndexMatcherGroup) Size() uint32 { + var count uint32 + for _, m := range g.Matchers { + count += m.Size() + } + return count +} diff --git a/common/utils/access_field.go b/common/utils/access_field.go new file mode 100644 index 00000000..bc42e67c --- /dev/null +++ b/common/utils/access_field.go @@ -0,0 +1,17 @@ +package utils + +import ( + "reflect" + "unsafe" +) + +// AccessField can used to access unexported field of a struct +// valueType must be the exact type of the field or it will panic +func AccessField[valueType any](obj any, fieldName string) *valueType { + field := reflect.ValueOf(obj).Elem().FieldByName(fieldName) + if field.Type() != reflect.TypeOf(*new(valueType)) { + panic("field type: " + field.Type().String() + ", valueType: " + reflect.TypeOf(*new(valueType)).String()) + } + v := (*valueType)(unsafe.Pointer(field.UnsafeAddr())) + return v +} diff --git a/common/utils/browser.go b/common/utils/browser.go new file mode 100644 index 00000000..91209f4b --- /dev/null +++ b/common/utils/browser.go @@ -0,0 +1,28 @@ +package utils + +import ( + "math/rand" + "strconv" + "time" + + "github.com/klauspost/cpuid/v2" +) + +func ChromeVersion() int { + // Use only CPU info as seed for PRNG + seed := int64(cpuid.CPU.Family + cpuid.CPU.Model + cpuid.CPU.PhysicalCores + cpuid.CPU.LogicalCores + cpuid.CPU.CacheLine) + rng := rand.New(rand.NewSource(seed)) + // Start from Chrome 144 released on 2026.1.13 + releaseDate := time.Date(2026, 1, 13, 0, 0, 0, 0, time.UTC) + version := 144 + now := time.Now() + // Each version has random 25-45 day interval + for releaseDate.Before(now) { + releaseDate = releaseDate.AddDate(0, 0, rng.Intn(21)+25) + version++ + } + return version - 1 +} + +// ChromeUA provides default browser User-Agent based on CPU-seeded PRNG. +var ChromeUA = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/" + strconv.Itoa(ChromeVersion()) + ".0.0.0 Safari/537.36" diff --git a/common/utils/padding.go b/common/utils/padding.go new file mode 100644 index 00000000..fe95ba9a --- /dev/null +++ b/common/utils/padding.go @@ -0,0 +1,24 @@ +package utils + +import ( + "math/rand/v2" +) + +var ( + // 8 รท (397/62) + h2packCorrectionFactor = 1.2493702770780857 + base62TotalCharsNum = 62 + base62Chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +) + +// H2Base62Pad generates a base62 padding string for HTTP/2 header +// The total len will be slightly longer than the input to match the length after h2(h3 also) header huffman encoding +func H2Base62Pad[T int32 | int64 | int](expectedLen T) string { + actualLenFloat := float64(expectedLen) * h2packCorrectionFactor + actualLen := int(actualLenFloat) + result := make([]byte, actualLen) + for i := range actualLen { + result[i] = base62Chars[rand.N(base62TotalCharsNum)] + } + return string(result) +} diff --git a/core/core.go b/core/core.go index 3cbd2bc9..c244a94a 100644 --- a/core/core.go +++ b/core/core.go @@ -18,8 +18,8 @@ import ( var ( Version_x byte = 26 - Version_y byte = 1 - Version_z byte = 23 + Version_y byte = 2 + Version_z byte = 6 ) var ( diff --git a/go.mod b/go.mod index 8117d8ed..84d03b84 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/amnezia-vpn/amnezia-xray-core -go 1.25.6 +go 1.25.7 require ( github.com/apernet/quic-go v0.57.2-0.20260111184307-eec823306178 @@ -9,6 +9,7 @@ require ( github.com/golang/mock v1.7.0-rc.1 github.com/google/go-cmp v0.7.0 github.com/gorilla/websocket v1.5.3 + github.com/klauspost/cpuid/v2 v2.0.12 github.com/miekg/dns v1.1.72 github.com/pelletier/go-toml v1.9.5 github.com/pires/go-proxyproto v0.9.2 @@ -39,7 +40,6 @@ require ( github.com/google/btree v1.1.2 // indirect github.com/juju/ratelimit v1.0.2 // indirect github.com/klauspost/compress v1.17.4 // indirect - github.com/klauspost/cpuid/v2 v2.0.12 // indirect github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect diff --git a/infra/conf/router.go b/infra/conf/router.go index 455ee781..1aa43749 100644 --- a/infra/conf/router.go +++ b/infra/conf/router.go @@ -12,6 +12,7 @@ import ( "github.com/amnezia-vpn/amnezia-xray-core/app/router" "github.com/amnezia-vpn/amnezia-xray-core/common/errors" "github.com/amnezia-vpn/amnezia-xray-core/common/net" + "github.com/amnezia-vpn/amnezia-xray-core/common/platform" "github.com/amnezia-vpn/amnezia-xray-core/common/platform/filesystem" "github.com/amnezia-vpn/amnezia-xray-core/common/serial" "google.golang.org/protobuf/proto" @@ -204,6 +205,13 @@ func loadIP(file, code string) ([]*router.CIDR, error) { } func loadSite(file, code string) ([]*router.Domain, error) { + + // Check if domain matcher cache is provided via environment + domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) + if domainMatcherPath != "" { + return []*router.Domain{{}}, nil + } + bs, err := loadFile(file, code) if err != nil { return nil, err diff --git a/infra/conf/transport_authenticators.go b/infra/conf/transport_authenticators.go index a209ab6e..65b17ec0 100644 --- a/infra/conf/transport_authenticators.go +++ b/infra/conf/transport_authenticators.go @@ -4,72 +4,18 @@ import ( "sort" "github.com/amnezia-vpn/amnezia-xray-core/common/errors" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/dns" + "github.com/amnezia-vpn/amnezia-xray-core/common/utils" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/http" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/noop" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/srtp" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/tls" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/utp" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wechat" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wireguard" "google.golang.org/protobuf/proto" ) -type NoOpAuthenticator struct{} - -func (NoOpAuthenticator) Build() (proto.Message, error) { - return new(noop.Config), nil -} - type NoOpConnectionAuthenticator struct{} func (NoOpConnectionAuthenticator) Build() (proto.Message, error) { return new(noop.ConnectionConfig), nil } -type SRTPAuthenticator struct{} - -func (SRTPAuthenticator) Build() (proto.Message, error) { - return new(srtp.Config), nil -} - -type UTPAuthenticator struct{} - -func (UTPAuthenticator) Build() (proto.Message, error) { - return new(utp.Config), nil -} - -type WechatVideoAuthenticator struct{} - -func (WechatVideoAuthenticator) Build() (proto.Message, error) { - return new(wechat.VideoConfig), nil -} - -type WireguardAuthenticator struct{} - -func (WireguardAuthenticator) Build() (proto.Message, error) { - return new(wireguard.WireguardConfig), nil -} - -type DNSAuthenticator struct { - Domain string `json:"domain"` -} - -func (v *DNSAuthenticator) Build() (proto.Message, error) { - config := new(dns.Config) - config.Domain = "www.baidu.com" - if len(v.Domain) > 0 { - config.Domain = v.Domain - } - return config, nil -} - -type DTLSAuthenticator struct{} - -func (DTLSAuthenticator) Build() (proto.Message, error) { - return new(tls.PacketConfig), nil -} - type AuthenticatorRequest struct { Version string `json:"version"` Method string `json:"method"` @@ -95,11 +41,8 @@ func (v *AuthenticatorRequest) Build() (*http.RequestConfig, error) { Value: []string{"www.baidu.com", "www.bing.com"}, }, { - Name: "User-Agent", - Value: []string{ - "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/53.0.2785.143 Safari/537.36", - "Mozilla/5.0 (iPhone; CPU iPhone OS 10_0_2 like Mac OS X) AppleWebKit/601.1 (KHTML, like Gecko) CriOS/53.0.2785.109 Mobile/14A456 Safari/601.1.46", - }, + Name: "User-Agent", + Value: []string{utils.ChromeUA}, }, { Name: "Accept-Encoding", diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 17ffc4fb..1629b265 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -1,6 +1,7 @@ package conf import ( + "context" "encoding/base64" "encoding/hex" "encoding/json" @@ -10,13 +11,24 @@ import ( "strconv" "strings" "syscall" + "time" "github.com/amnezia-vpn/amnezia-xray-core/common/errors" "github.com/amnezia-vpn/amnezia-xray-core/common/net" "github.com/amnezia-vpn/amnezia-xray-core/common/platform/filesystem" "github.com/amnezia-vpn/amnezia-xray-core/common/serial" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dns" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dtls" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/srtp" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/utp" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wechat" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wireguard" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/aes128gcm" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/original" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/salamander" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xdns" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xicmp" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/httpupgrade" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/hysteria" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/kcp" @@ -29,16 +41,6 @@ import ( ) var ( - kcpHeaderLoader = NewJSONConfigLoader(ConfigCreatorCache{ - "none": func() interface{} { return new(NoOpAuthenticator) }, - "srtp": func() interface{} { return new(SRTPAuthenticator) }, - "utp": func() interface{} { return new(UTPAuthenticator) }, - "wechat-video": func() interface{} { return new(WechatVideoAuthenticator) }, - "dtls": func() interface{} { return new(DTLSAuthenticator) }, - "wireguard": func() interface{} { return new(WireguardAuthenticator) }, - "dns": func() interface{} { return new(DNSAuthenticator) }, - }, "type", "") - tcpHeaderLoader = NewJSONConfigLoader(ConfigCreatorCache{ "none": func() interface{} { return new(NoOpConnectionAuthenticator) }, "http": func() interface{} { return new(Authenticator) }, @@ -63,9 +65,9 @@ func (c *KCPConfig) Build() (proto.Message, error) { if c.Mtu != nil { mtu := *c.Mtu - if mtu < 576 || mtu > 1460 { - return nil, errors.New("invalid mKCP MTU size: ", mtu).AtError() - } + // if mtu < 576 || mtu > 1460 { + // return nil, errors.New("invalid mKCP MTU size: ", mtu).AtError() + // } config.Mtu = &kcp.MTU{Value: mtu} } if c.Tti != nil { @@ -100,20 +102,8 @@ func (c *KCPConfig) Build() (proto.Message, error) { config.WriteBuffer = &kcp.WriteBuffer{Size: 512 * 1024} } } - if len(c.HeaderConfig) > 0 { - headerConfig, _, err := kcpHeaderLoader.Load(c.HeaderConfig) - if err != nil { - return nil, errors.New("invalid mKCP header config.").Base(err).AtError() - } - ts, err := headerConfig.(Buildable).Build() - if err != nil { - return nil, errors.New("invalid mKCP header config").Base(err).AtError() - } - config.HeaderConfig = serial.ToTypedMessage(ts) - } - - if c.Seed != nil { - config.Seed = &kcp.EncryptionSeed{Seed: *c.Seed} + if c.HeaderConfig != nil || c.Seed != nil { + return nil, errors.PrintRemovedFeatureError("mkcp header & seed", "finalmask/udp header-* & mkcp-original & mkcp-aes128gcm") } return config, nil @@ -228,6 +218,19 @@ type SplitHTTPConfig struct { Mode string `json:"mode"` Headers map[string]string `json:"headers"` XPaddingBytes Int32Range `json:"xPaddingBytes"` + XPaddingObfsMode bool `json:"xPaddingObfsMode"` + XPaddingKey string `json:"xPaddingKey"` + XPaddingHeader string `json:"xPaddingHeader"` + XPaddingPlacement string `json:"xPaddingPlacement"` + XPaddingMethod string `json:"xPaddingMethod"` + UplinkHTTPMethod string `json:"uplinkHTTPMethod"` + SessionPlacement string `json:"sessionPlacement"` + SessionKey string `json:"sessionKey"` + SeqPlacement string `json:"seqPlacement"` + SeqKey string `json:"seqKey"` + UplinkDataPlacement string `json:"uplinkDataPlacement"` + UplinkDataKey string `json:"uplinkDataKey"` + UplinkChunkSize uint32 `json:"uplinkChunkSize"` NoGRPCHeader bool `json:"noGRPCHeader"` NoSSEHeader bool `json:"noSSEHeader"` ScMaxEachPostBytes Int32Range `json:"scMaxEachPostBytes"` @@ -287,6 +290,108 @@ func (c *SplitHTTPConfig) Build() (proto.Message, error) { return nil, errors.New("xPaddingBytes cannot be disabled") } + if c.XPaddingKey == "" { + c.XPaddingKey = "x_padding" + } + + if c.XPaddingHeader == "" { + c.XPaddingHeader = "X-Padding" + } + + switch c.XPaddingPlacement { + case "": + c.XPaddingPlacement = "queryInHeader" + case "cookie", "header", "query", "queryInHeader": + default: + return nil, errors.New("unsupported padding placement: " + c.XPaddingPlacement) + } + + switch c.XPaddingMethod { + case "": + c.XPaddingMethod = "repeat-x" + case "repeat-x", "tokenish": + default: + return nil, errors.New("unsupported padding method: " + c.XPaddingMethod) + } + + switch c.UplinkDataPlacement { + case "": + c.UplinkDataPlacement = "body" + case "body": + case "cookie", "header": + if c.Mode != "packet-up" { + return nil, errors.New("UplinkDataPlacement can be " + c.UplinkDataPlacement + " only in packet-up mode") + } + default: + return nil, errors.New("unsupported uplink data placement: " + c.UplinkDataPlacement) + } + + if c.UplinkHTTPMethod == "" { + c.UplinkHTTPMethod = "POST" + } + c.UplinkHTTPMethod = strings.ToUpper(c.UplinkHTTPMethod) + + if c.UplinkHTTPMethod == "GET" && c.Mode != "packet-up" { + return nil, errors.New("uplinkHTTPMethod can be GET only in packet-up mode") + } + + switch c.SessionPlacement { + case "": + c.SessionPlacement = "path" + case "path", "cookie", "header", "query": + default: + return nil, errors.New("unsupported session placement: " + c.SessionPlacement) + } + + switch c.SeqPlacement { + case "": + c.SeqPlacement = "path" + case "path", "cookie", "header", "query": + if c.SessionPlacement == "path" { + return nil, errors.New("SeqPlacement must be path when SessionPlacement is path") + } + default: + return nil, errors.New("unsupported seq placement: " + c.SeqPlacement) + } + + if c.SessionPlacement != "path" && c.SessionKey == "" { + switch c.SessionPlacement { + case "cookie", "query": + c.SessionKey = "x_session" + case "header": + c.SessionKey = "X-Session" + } + } + + if c.SeqPlacement != "path" && c.SeqKey == "" { + switch c.SeqPlacement { + case "cookie", "query": + c.SeqKey = "x_seq" + case "header": + c.SeqKey = "X-Seq" + } + } + + if c.UplinkDataPlacement != "body" && c.UplinkDataKey == "" { + switch c.UplinkDataPlacement { + case "cookie": + c.UplinkDataKey = "x_data" + case "header": + c.UplinkDataKey = "X-Data" + } + } + + if c.UplinkChunkSize == 0 { + switch c.UplinkDataPlacement { + case "cookie": + c.UplinkChunkSize = 3 * 1024 // 3KB + case "header": + c.UplinkChunkSize = 4 * 1024 // 4KB + } + } else if c.UplinkChunkSize < 64 { + c.UplinkChunkSize = 64 + } + if c.Xmux.MaxConnections.To > 0 && c.Xmux.MaxConcurrency.To > 0 { return nil, errors.New("maxConnections cannot be specified together with maxConcurrency") } @@ -305,6 +410,19 @@ func (c *SplitHTTPConfig) Build() (proto.Message, error) { Mode: c.Mode, Headers: c.Headers, XPaddingBytes: newRangeConfig(c.XPaddingBytes), + XPaddingObfsMode: c.XPaddingObfsMode, + XPaddingKey: c.XPaddingKey, + XPaddingHeader: c.XPaddingHeader, + XPaddingPlacement: c.XPaddingPlacement, + XPaddingMethod: c.XPaddingMethod, + UplinkHTTPMethod: c.UplinkHTTPMethod, + SessionPlacement: c.SessionPlacement, + SeqPlacement: c.SeqPlacement, + SessionKey: c.SessionKey, + SeqKey: c.SeqKey, + UplinkDataPlacement: c.UplinkDataPlacement, + UplinkDataKey: c.UplinkDataKey, + UplinkChunkSize: c.UplinkChunkSize, NoGRPCHeader: c.NoGRPCHeader, NoSSEHeader: c.NoSSEHeader, ScMaxEachPostBytes: newRangeConfig(c.ScMaxEachPostBytes), @@ -631,7 +749,12 @@ func (c *TLSConfig) Build() (proto.Message, error) { config.MasterKeyLog = c.MasterKeyLog if c.AllowInsecure { - return nil, errors.PrintRemovedFeatureError(`"allowInsecure"`, `"pinnedPeerCertSha256"`) + if time.Now().After(time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC)) { + return nil, errors.PrintRemovedFeatureError(`"allowInsecure"`, `"pinnedPeerCertSha256"`) + } else { + errors.LogWarning(context.Background(), `"allowInsecure" will be removed automatically after 2026-06-01, please use "pinnedPeerCertSha256"(pcs) and "verifyPeerCertByName"(vcn) instead, PLEASE CONTACT YOUR SERVICE PROVIDER (AIRPORT)`) + config.AllowInsecure = true + } } if c.PinnedPeerCertSha256 != "" { for v := range strings.SplitSeq(c.PinnedPeerCertSha256, ",") { @@ -639,10 +762,14 @@ func (c *TLSConfig) Build() (proto.Message, error) { if v == "" { continue } - hashValue, err := hex.DecodeString(v) + // remove colons for OpenSSL format + hashValue, err := hex.DecodeString(strings.ReplaceAll(v, ":", "")) if err != nil { return nil, err } + if len(hashValue) != 32 { + return nil, errors.New("incorrect pinnedPeerCertSha256 length: ", v) + } config.PinnedPeerCertSha256 = append(config.PinnedPeerCertSha256, hashValue) } } @@ -1111,10 +1238,81 @@ func (c *SocketConfig) Build() (*internet.SocketConfig, error) { var ( udpmaskLoader = NewJSONConfigLoader(ConfigCreatorCache{ - "salamander": func() interface{} { return new(Salamander) }, + "header-dns": func() interface{} { return new(Dns) }, + "header-dtls": func() interface{} { return new(Dtls) }, + "header-srtp": func() interface{} { return new(Srtp) }, + "header-utp": func() interface{} { return new(Utp) }, + "header-wechat": func() interface{} { return new(Wechat) }, + "header-wireguard": func() interface{} { return new(Wireguard) }, + "mkcp-original": func() interface{} { return new(Original) }, + "mkcp-aes128gcm": func() interface{} { return new(Aes128Gcm) }, + "salamander": func() interface{} { return new(Salamander) }, + "xdns": func() interface{} { return new(Xdns) }, + "xicmp": func() interface{} { return new(Xicmp) }, }, "type", "settings") ) +type Dns struct { + Domain string `json:"domain"` +} + +func (c *Dns) Build() (proto.Message, error) { + config := &dns.Config{} + config.Domain = "www.baidu.com" + + if len(c.Domain) > 0 { + config.Domain = c.Domain + } + + return config, nil +} + +type Dtls struct{} + +func (c *Dtls) Build() (proto.Message, error) { + return &dtls.Config{}, nil +} + +type Srtp struct{} + +func (c *Srtp) Build() (proto.Message, error) { + return &srtp.Config{}, nil +} + +type Utp struct{} + +func (c *Utp) Build() (proto.Message, error) { + return &utp.Config{}, nil +} + +type Wechat struct{} + +func (c *Wechat) Build() (proto.Message, error) { + return &wechat.Config{}, nil +} + +type Wireguard struct{} + +func (c *Wireguard) Build() (proto.Message, error) { + return &wireguard.Config{}, nil +} + +type Original struct{} + +func (c *Original) Build() (proto.Message, error) { + return &original.Config{}, nil +} + +type Aes128Gcm struct { + Password string `json:"password"` +} + +func (c *Aes128Gcm) Build() (proto.Message, error) { + return &aes128gcm.Config{ + Password: c.Password, + }, nil +} + type Salamander struct { Password string `json:"password"` } @@ -1125,14 +1323,46 @@ func (c *Salamander) Build() (proto.Message, error) { return config, nil } -type FinalMask struct { +type Xdns struct { + Domain string `json:"domain"` +} + +func (c *Xdns) Build() (proto.Message, error) { + if c.Domain == "" { + return nil, errors.New("empty domain") + } + + return &xdns.Config{ + Domain: c.Domain, + }, nil +} + +type Xicmp struct { + ListenIp string `json:"listenIp"` + Id uint16 `json:"id"` +} + +func (c *Xicmp) Build() (proto.Message, error) { + config := &xicmp.Config{ + Ip: c.ListenIp, + Id: int32(c.Id), + } + + if config.Ip == "" { + config.Ip = "0.0.0.0" + } + + return config, nil +} + +type Mask struct { Type string `json:"type"` Settings *json.RawMessage `json:"settings"` } -func (c *FinalMask) Build(tcpmaskLoader bool) (proto.Message, error) { +func (c *Mask) Build(tcp bool) (proto.Message, error) { loader := udpmaskLoader - if tcpmaskLoader { + if tcp { return nil, errors.New("") } @@ -1151,12 +1381,17 @@ func (c *FinalMask) Build(tcpmaskLoader bool) (proto.Message, error) { return ts, nil } +type FinalMask struct { + Tcp []Mask `json:"tcp"` + Udp []Mask `json:"udp"` +} + type StreamConfig struct { Address *Address `json:"address"` Port uint16 `json:"port"` Network *TransportProtocol `json:"network"` Security string `json:"security"` - Udpmasks []*FinalMask `json:"udpmasks"` + FinalMask *FinalMask `json:"finalmask"` TLSSettings *TLSConfig `json:"tlsSettings"` REALITYSettings *REALITYConfig `json:"realitySettings"` RAWSettings *TCPConfig `json:"rawSettings"` @@ -1306,12 +1541,21 @@ func (c *StreamConfig) Build() (*internet.StreamConfig, error) { config.SocketSettings = ss } - for _, mask := range c.Udpmasks { - u, err := mask.Build(false) - if err != nil { - return nil, errors.New("failed to build mask with type ", mask.Type).Base(err) + if c.FinalMask != nil { + for _, mask := range c.FinalMask.Tcp { + u, err := mask.Build(true) + if err != nil { + return nil, errors.New("failed to build mask with type ", mask.Type).Base(err) + } + config.Tcpmasks = append(config.Tcpmasks, serial.ToTypedMessage(u)) + } + for _, mask := range c.FinalMask.Udp { + u, err := mask.Build(false) + if err != nil { + return nil, errors.New("failed to build mask with type ", mask.Type).Base(err) + } + config.Udpmasks = append(config.Udpmasks, serial.ToTypedMessage(u)) } - config.Udpmasks = append(config.Udpmasks, serial.ToTypedMessage(u)) } return config, nil diff --git a/infra/conf/xray.go b/infra/conf/xray.go index eec2a13d..e4e58ce8 100644 --- a/infra/conf/xray.go +++ b/infra/conf/xray.go @@ -1,16 +1,21 @@ package conf import ( + "bytes" "context" "encoding/json" + "os" "path/filepath" + "sort" "strings" "github.com/amnezia-vpn/amnezia-xray-core/app/dispatcher" "github.com/amnezia-vpn/amnezia-xray-core/app/proxyman" + "github.com/amnezia-vpn/amnezia-xray-core/app/router" "github.com/amnezia-vpn/amnezia-xray-core/app/stats" "github.com/amnezia-vpn/amnezia-xray-core/common/errors" "github.com/amnezia-vpn/amnezia-xray-core/common/net" + "github.com/amnezia-vpn/amnezia-xray-core/common/platform" "github.com/amnezia-vpn/amnezia-xray-core/common/serial" core "github.com/amnezia-vpn/amnezia-xray-core/core" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet" @@ -607,6 +612,187 @@ func (c *Config) Build() (*core.Config, error) { return config, nil } +func (c *Config) BuildMPHCache(customMatcherFilePath *string) error { + var geosite []*router.GeoSite + deps := make(map[string][]string) + uniqueGeosites := make(map[string]bool) + uniqueTags := make(map[string]bool) + matcherFilePath := platform.GetAssetLocation("matcher.cache") + + if customMatcherFilePath != nil { + matcherFilePath = *customMatcherFilePath + } + + processGeosite := func(dStr string) bool { + prefix := "" + if strings.HasPrefix(dStr, "geosite:") { + prefix = "geosite:" + } else if strings.HasPrefix(dStr, "ext-domain:") { + prefix = "ext-domain:" + } + if prefix == "" { + return false + } + key := strings.ToLower(dStr) + country := strings.ToUpper(dStr[len(prefix):]) + if !uniqueGeosites[country] { + ds, err := loadGeositeWithAttr("geosite.dat", country) + if err == nil { + uniqueGeosites[country] = true + geosite = append(geosite, &router.GeoSite{CountryCode: key, Domain: ds}) + } + } + return true + } + + processDomains := func(tag string, rawDomains []string) { + var manualDomains []*router.Domain + var dDeps []string + for _, dStr := range rawDomains { + if processGeosite(dStr) { + dDeps = append(dDeps, strings.ToLower(dStr)) + } else { + ds, err := parseDomainRule(dStr) + if err == nil { + manualDomains = append(manualDomains, ds...) + } + } + } + if len(manualDomains) > 0 { + if !uniqueTags[tag] { + uniqueTags[tag] = true + geosite = append(geosite, &router.GeoSite{CountryCode: tag, Domain: manualDomains}) + } + } + if len(dDeps) > 0 { + deps[tag] = append(deps[tag], dDeps...) + } + } + + // proccess rules + if c.RouterConfig != nil { + for _, rawRule := range c.RouterConfig.RuleList { + type SimpleRule struct { + RuleTag string `json:"ruleTag"` + Domain *StringList `json:"domain"` + Domains *StringList `json:"domains"` + } + var sr SimpleRule + json.Unmarshal(rawRule, &sr) + if sr.RuleTag == "" { + continue + } + var allDomains []string + if sr.Domain != nil { + allDomains = append(allDomains, *sr.Domain...) + } + if sr.Domains != nil { + allDomains = append(allDomains, *sr.Domains...) + } + processDomains(sr.RuleTag, allDomains) + } + } + + // proccess dns servers + if c.DNSConfig != nil { + for _, ns := range c.DNSConfig.Servers { + if ns.Tag == "" { + continue + } + processDomains(ns.Tag, ns.Domains) + } + } + + var hostIPs map[string][]string + if c.DNSConfig != nil && c.DNSConfig.Hosts != nil { + hostIPs = make(map[string][]string) + var hostDeps []string + var hostPatterns []string + + // use raw map to avoid expanding geosites + var domains []string + for domain := range c.DNSConfig.Hosts.Hosts { + domains = append(domains, domain) + } + sort.Strings(domains) + + manualHostGroups := make(map[string][]*router.Domain) + manualHostIPs := make(map[string][]string) + manualHostNames := make(map[string]string) + + for _, domain := range domains { + ha := c.DNSConfig.Hosts.Hosts[domain] + m := getHostMapping(ha) + + var ips []string + if m.ProxiedDomain != "" { + ips = append(ips, m.ProxiedDomain) + } else { + for _, ip := range m.Ip { + ips = append(ips, net.IPAddress(ip).String()) + } + } + + if processGeosite(domain) { + tag := strings.ToLower(domain) + hostDeps = append(hostDeps, tag) + hostIPs[tag] = ips + hostPatterns = append(hostPatterns, domain) + } else { + // build manual domains by their destination IPs + sort.Strings(ips) + ipKey := strings.Join(ips, ",") + ds, err := parseDomainRule(domain) + if err == nil { + manualHostGroups[ipKey] = append(manualHostGroups[ipKey], ds...) + manualHostIPs[ipKey] = ips + if _, ok := manualHostNames[ipKey]; !ok { + manualHostNames[ipKey] = domain + } + } + } + } + + // create manual host groups + var ipKeys []string + for k := range manualHostGroups { + ipKeys = append(ipKeys, k) + } + sort.Strings(ipKeys) + + for _, k := range ipKeys { + tag := manualHostNames[k] + geosite = append(geosite, &router.GeoSite{CountryCode: tag, Domain: manualHostGroups[k]}) + hostDeps = append(hostDeps, tag) + hostIPs[tag] = manualHostIPs[k] + + // record tag _ORDER links the matcher to IP addresses + hostPatterns = append(hostPatterns, tag) + } + + deps["HOSTS"] = hostDeps + hostIPs["_ORDER"] = hostPatterns + } + + f, err := os.Create(matcherFilePath) + if err != nil { + return err + } + defer f.Close() + + var buf bytes.Buffer + + if err := router.SerializeGeoSiteList(geosite, deps, hostIPs, &buf); err != nil { + return err + } + + if _, err := f.Write(buf.Bytes()); err != nil { + return err + } + + return nil +} + // Convert string to Address. func ParseSendThough(Addr *string) *Address { var addr Address diff --git a/main/commands/all/buildmphcache.go b/main/commands/all/buildmphcache.go new file mode 100644 index 00000000..ba9cc764 --- /dev/null +++ b/main/commands/all/buildmphcache.go @@ -0,0 +1,52 @@ +package all + +import ( + "os" + + "github.com/amnezia-vpn/amnezia-xray-core/common/platform" + "github.com/amnezia-vpn/amnezia-xray-core/infra/conf/serial" + "github.com/amnezia-vpn/amnezia-xray-core/main/commands/base" +) + +var cmdBuildMphCache = &base.Command{ + UsageLine: `{{.Exec}} buildMphCache [-c config.json] [-o domain.cache]`, + Short: `Build domain matcher cache`, + Long: ` +Build domain matcher cache from a configuration file. + +Example: {{.Exec}} buildMphCache -c config.json -o domain.cache +`, +} + +func init() { + cmdBuildMphCache.Run = executeBuildMphCache +} + +var ( + configPath = cmdBuildMphCache.Flag.String("c", "config.json", "Config file path") + outputPath = cmdBuildMphCache.Flag.String("o", "domain.cache", "Output cache file path") +) + +func executeBuildMphCache(cmd *base.Command, args []string) { + cf, err := os.Open(*configPath) + if err != nil { + base.Fatalf("failed to open config file: %v", err) + } + defer cf.Close() + + // prevent using existing cache + domainMatcherPath := platform.NewEnvFlag(platform.MphCachePath).GetValue(func() string { return "" }) + if domainMatcherPath != "" { + os.Setenv("XRAY_MPH_CACHE", "") + defer os.Setenv("XRAY_MPH_CACHE", domainMatcherPath) + } + + config, err := serial.DecodeJSONConfig(cf) + if err != nil { + base.Fatalf("failed to decode config file: %v", err) + } + + if err := config.BuildMPHCache(outputPath); err != nil { + base.Fatalf("failed to build MPH cache: %v", err) + } +} diff --git a/main/commands/all/commands.go b/main/commands/all/commands.go index 775ff360..dc42729e 100644 --- a/main/commands/all/commands.go +++ b/main/commands/all/commands.go @@ -19,5 +19,6 @@ func init() { cmdMLDSA65, cmdMLKEM768, cmdVLESSEnc, + cmdBuildMphCache, ) } diff --git a/main/commands/all/tls/hash.go b/main/commands/all/tls/hash.go new file mode 100644 index 00000000..6507ff98 --- /dev/null +++ b/main/commands/all/tls/hash.go @@ -0,0 +1,78 @@ +package tls + +import ( + "bytes" + "crypto/x509" + "encoding/pem" + "flag" + "fmt" + "os" + "text/tabwriter" + + "github.com/amnezia-vpn/amnezia-xray-core/main/commands/base" + . "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/tls" +) + +var cmdHash = &base.Command{ + UsageLine: "{{.Exec}} tls hash", + Short: "Calculate TLS certificate hash.", + Long: ` + xray tls hash --cert + Calculate TLS certificate hash. + `, +} + +func init() { + cmdHash.Run = executeHash // break init loop +} + +var input = cmdHash.Flag.String("cert", "fullchain.pem", "The file path of the certificate") + +func executeHash(cmd *base.Command, args []string) { + fs := flag.NewFlagSet("hash", flag.ContinueOnError) + if err := fs.Parse(args); err != nil { + fmt.Println(err) + return + } + certContent, err := os.ReadFile(*input) + if err != nil { + fmt.Println(err) + return + } + var certs []*x509.Certificate + if bytes.Contains(certContent, []byte("BEGIN")) { + for { + block, remain := pem.Decode(certContent) + if block == nil { + break + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + fmt.Println("Unable to decode certificate:", err) + return + } + certs = append(certs, cert) + certContent = remain + } + } else { + certs, err = x509.ParseCertificates(certContent) + if err != nil { + fmt.Println("Unable to parse certificates:", err) + return + } + } + if len(certs) == 0 { + fmt.Println("No certificates found") + return + } + tabWriter := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + for i, cert := range certs { + hash := GenerateCertHashHex(cert) + if i == 0 { + fmt.Fprintf(tabWriter, "Leaf SHA256:\t%s\n", hash) + } else { + fmt.Fprintf(tabWriter, "CA <%s> SHA256:\t%s\n", cert.Subject.CommonName, hash) + } + } + tabWriter.Flush() +} diff --git a/main/commands/all/tls/leafcerthash.go b/main/commands/all/tls/leafcerthash.go deleted file mode 100644 index 1bba08af..00000000 --- a/main/commands/all/tls/leafcerthash.go +++ /dev/null @@ -1,44 +0,0 @@ -package tls - -import ( - "flag" - "fmt" - "os" - - "github.com/amnezia-vpn/amnezia-xray-core/main/commands/base" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/tls" -) - -var cmdLeafCertHash = &base.Command{ - UsageLine: "{{.Exec}} tls leafCertHash", - Short: "Calculate TLS leaf certificate hash.", - Long: ` - xray tls leafCertHash --cert - Calculate TLS leaf certificate hash. - `, -} - -func init() { - cmdLeafCertHash.Run = executeLeafCertHash // break init loop -} - -var input = cmdLeafCertHash.Flag.String("cert", "fullchain.pem", "The file path of the leaf certificate") - -func executeLeafCertHash(cmd *base.Command, args []string) { - fs := flag.NewFlagSet("leafCertHash", flag.ContinueOnError) - if err := fs.Parse(args); err != nil { - fmt.Println(err) - return - } - certContent, err := os.ReadFile(*input) - if err != nil { - fmt.Println(err) - return - } - certChainHashB64, err := tls.CalculatePEMLeafCertSHA256Hash(certContent) - if err != nil { - fmt.Println("failed to decode cert", err) - return - } - fmt.Println(certChainHashB64) -} diff --git a/main/commands/all/tls/ping.go b/main/commands/all/tls/ping.go index c4290849..0e8ed09e 100644 --- a/main/commands/all/tls/ping.go +++ b/main/commands/all/tls/ping.go @@ -6,8 +6,13 @@ import ( "encoding/hex" "fmt" "net" + "os" "strconv" + "text/tabwriter" + utls "github.com/refraction-networking/utls" + + "github.com/amnezia-vpn/amnezia-xray-core/common/utils" "github.com/amnezia-vpn/amnezia-xray-core/main/commands/base" . "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/tls" ) @@ -46,6 +51,7 @@ func executePing(cmd *base.Command, args []string) { } else { TargetPort, _ = strconv.Atoi(port) } + tabWriter := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) var ip net.IP if len(*pingIPStr) > 0 { @@ -70,21 +76,20 @@ func executePing(cmd *base.Command, args []string) { if err != nil { base.Fatalf("Failed to dial tcp: %s", err) } - tlsConn := gotls.Client(tcpConn, &gotls.Config{ + tlsConn := GeneraticUClient(tcpConn, &gotls.Config{ InsecureSkipVerify: true, NextProtos: []string{"h2", "http/1.1"}, MaxVersion: gotls.VersionTLS13, MinVersion: gotls.VersionTLS12, - // Do not release tool before v5's refactor - // VerifyPeerCertificate: showCert(), }) err = tlsConn.Handshake() if err != nil { fmt.Println("Handshake failure: ", err) } else { fmt.Println("Handshake succeeded") - printTLSConnDetail(tlsConn) - printCertificates(tlsConn.ConnectionState().PeerCertificates) + printTLSConnDetail(tabWriter, tlsConn) + printCertificates(tabWriter, tlsConn.ConnectionState().PeerCertificates) + tabWriter.Flush() } tlsConn.Close() } @@ -96,21 +101,20 @@ func executePing(cmd *base.Command, args []string) { if err != nil { base.Fatalf("Failed to dial tcp: %s", err) } - tlsConn := gotls.Client(tcpConn, &gotls.Config{ + tlsConn := GeneraticUClient(tcpConn, &gotls.Config{ ServerName: domain, NextProtos: []string{"h2", "http/1.1"}, MaxVersion: gotls.VersionTLS13, MinVersion: gotls.VersionTLS12, - // Do not release tool before v5's refactor - // VerifyPeerCertificate: showCert(), }) err = tlsConn.Handshake() if err != nil { fmt.Println("Handshake failure: ", err) } else { fmt.Println("Handshake succeeded") - printTLSConnDetail(tlsConn) - printCertificates(tlsConn.ConnectionState().PeerCertificates) + printTLSConnDetail(tabWriter, tlsConn) + printCertificates(tabWriter, tlsConn.ConnectionState().PeerCertificates) + tabWriter.Flush() } tlsConn.Close() } @@ -119,51 +123,45 @@ func executePing(cmd *base.Command, args []string) { fmt.Println("TLS ping finished") } -func printCertificates(certs []*x509.Certificate) { +func printCertificates(tabWriter *tabwriter.Writer, certs []*x509.Certificate) { var leaf *x509.Certificate + var CAs []*x509.Certificate var length int for _, cert := range certs { length += len(cert.Raw) if len(cert.DNSNames) != 0 { leaf = cert + } else { + CAs = append(CAs, cert) } } - fmt.Println("Certificate chain's total length: ", length, "(certs count: "+strconv.Itoa(len(certs))+")") + fmt.Fprintf(tabWriter, "Certificate chain's total length:\t%d (certs count: %s)\n", length, strconv.Itoa(len(certs))) if leaf != nil { - fmt.Println("Cert's signature algorithm: ", leaf.SignatureAlgorithm.String()) - fmt.Println("Cert's publicKey algorithm: ", leaf.PublicKeyAlgorithm.String()) - fmt.Println("Cert's allowed domains: ", leaf.DNSNames) + fmt.Fprintf(tabWriter, "Cert's signature algorithm:\t%s\n", leaf.SignatureAlgorithm.String()) + fmt.Fprintf(tabWriter, "Cert's publicKey algorithm:\t%s\n", leaf.PublicKeyAlgorithm.String()) + fmt.Fprintf(tabWriter, "Cert's leaf SHA256:\t%s\n", hex.EncodeToString(GenerateCertHash(leaf))) + for _, ca := range CAs { + fmt.Fprintf(tabWriter, "Cert's CA <%s> SHA256:\t%s\n", ca.Subject.CommonName, hex.EncodeToString(GenerateCertHash(ca))) + } + fmt.Fprintf(tabWriter, "Cert's allowed domains:\t%v\n", leaf.DNSNames) } } -func printTLSConnDetail(tlsConn *gotls.Conn) { +func printTLSConnDetail(tabWriter *tabwriter.Writer, tlsConn *utls.UConn) { connectionState := tlsConn.ConnectionState() var tlsVersion string - if connectionState.Version == gotls.VersionTLS13 { + switch connectionState.Version { + case gotls.VersionTLS13: tlsVersion = "TLS 1.3" - } else if connectionState.Version == gotls.VersionTLS12 { + case gotls.VersionTLS12: tlsVersion = "TLS 1.2" } - fmt.Println("TLS Version: ", tlsVersion) - curveID := connectionState.CurveID - if curveID != 0 { - PostQuantum := (curveID == gotls.X25519MLKEM768) - fmt.Println("TLS Post-Quantum key exchange: ", PostQuantum, "("+curveID.String()+")") + fmt.Fprintf(tabWriter, "TLS Version:\t%s\n", tlsVersion) + curveID := utils.AccessField[utls.CurveID](tlsConn.Conn, "curveID") + if curveID != nil { + PostQuantum := (*curveID == utls.X25519MLKEM768) + fmt.Fprintf(tabWriter, "TLS Post-Quantum key exchange:\t%t (%s)\n", PostQuantum, curveID.String()) } else { - fmt.Println("TLS Post-Quantum key exchange: false (RSA Exchange)") - } -} - -func showCert() func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - var hash []byte - for _, asn1Data := range rawCerts { - cert, _ := x509.ParseCertificate(asn1Data) - if cert.IsCA { - hash = GenerateCertHash(cert) - } - } - fmt.Println("Certificate Leaf Hash: ", hex.EncodeToString(hash)) - return nil + fmt.Fprintf(tabWriter, "TLS Post-Quantum key exchange: false (RSA Exchange)\n") } } diff --git a/main/commands/all/tls/tls.go b/main/commands/all/tls/tls.go index 05aaff51..c805684a 100644 --- a/main/commands/all/tls/tls.go +++ b/main/commands/all/tls/tls.go @@ -13,7 +13,7 @@ var CmdTLS = &base.Command{ Commands: []*base.Command{ cmdCert, cmdPing, - cmdLeafCertHash, + cmdHash, cmdECH, }, } diff --git a/main/distro/all/all.go b/main/distro/all/all.go index 4930c908..0622e41a 100644 --- a/main/distro/all/all.go +++ b/main/distro/all/all.go @@ -63,11 +63,6 @@ import ( // Transport headers _ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/http" _ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/noop" - _ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/srtp" - _ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/tls" - _ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/utp" - _ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wechat" - _ "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wireguard" // JSON & TOML & YAML _ "github.com/amnezia-vpn/amnezia-xray-core/main/json" diff --git a/proxy/http/client.go b/proxy/http/client.go index 8837f442..7bac026b 100644 --- a/proxy/http/client.go +++ b/proxy/http/client.go @@ -21,6 +21,7 @@ import ( "github.com/amnezia-vpn/amnezia-xray-core/common/session" "github.com/amnezia-vpn/amnezia-xray-core/common/signal" "github.com/amnezia-vpn/amnezia-xray-core/common/task" + "github.com/amnezia-vpn/amnezia-xray-core/common/utils" "github.com/amnezia-vpn/amnezia-xray-core/core" "github.com/amnezia-vpn/amnezia-xray-core/features/policy" "github.com/amnezia-vpn/amnezia-xray-core/transport" @@ -219,6 +220,9 @@ func setUpHTTPTunnel(ctx context.Context, dest net.Destination, target string, u for _, h := range header { req.Header.Set(h.Key, h.Value) } + if req.Header.Get("User-Agent") == "" { + req.Header.Set("User-Agent", utils.ChromeUA) + } connectHTTP1 := func(rawConn net.Conn) (net.Conn, error) { req.Header.Set("Proxy-Connection", "Keep-Alive") diff --git a/transport/internet/finalmask/finalmask.go b/transport/internet/finalmask/finalmask.go index 7ce4d4f3..d8a289a7 100644 --- a/transport/internet/finalmask/finalmask.go +++ b/transport/internet/finalmask/finalmask.go @@ -4,17 +4,15 @@ import ( "net" ) +type ConnSize interface { + Size() int32 +} + type Udpmask interface { UDP() - WrapConnClient(net.Conn) (net.Conn, error) - WrapConnServer(net.Conn) (net.Conn, error) - - WrapPacketConnClient(net.PacketConn) (net.PacketConn, error) - WrapPacketConnServer(net.PacketConn) (net.PacketConn, error) - - Size() int - Serialize([]byte) + WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) + WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) } type UdpmaskManager struct { @@ -27,66 +25,32 @@ func NewUdpmaskManager(udpmasks []Udpmask) *UdpmaskManager { } } -func (m *UdpmaskManager) WrapConnClient(raw net.Conn) (net.Conn, error) { - var err error - for _, mask := range m.udpmasks { - raw, err = mask.WrapConnClient(raw) - if err != nil { - return nil, err - } - } - return raw, nil -} - -func (m *UdpmaskManager) WrapConnServer(raw net.Conn) (net.Conn, error) { - var err error - for _, mask := range m.udpmasks { - raw, err = mask.WrapConnServer(raw) - if err != nil { - return nil, err - } - } - return raw, nil -} - func (m *UdpmaskManager) WrapPacketConnClient(raw net.PacketConn) (net.PacketConn, error) { + leaveSize := int32(0) var err error - for _, mask := range m.udpmasks { - raw, err = mask.WrapPacketConnClient(raw) + for i, mask := range m.udpmasks { + raw, err = mask.WrapPacketConnClient(raw, i == len(m.udpmasks)-1, leaveSize, i == 0) if err != nil { return nil, err } + leaveSize += raw.(ConnSize).Size() } return raw, nil } func (m *UdpmaskManager) WrapPacketConnServer(raw net.PacketConn) (net.PacketConn, error) { + leaveSize := int32(0) var err error - for _, mask := range m.udpmasks { - raw, err = mask.WrapPacketConnServer(raw) + for i, mask := range m.udpmasks { + raw, err = mask.WrapPacketConnServer(raw, i == len(m.udpmasks)-1, leaveSize, i == 0) if err != nil { return nil, err } + leaveSize += raw.(ConnSize).Size() } return raw, nil } -func (m *UdpmaskManager) Size() int { - size := 0 - for _, mask := range m.udpmasks { - size += mask.Size() - } - return size -} - -func (m *UdpmaskManager) Serialize(b []byte) { - index := 0 - for _, mask := range m.udpmasks { - mask.Serialize(b[index:]) - index += mask.Size() - } -} - type Tcpmask interface { TCP() diff --git a/transport/internet/finalmask/header/dns/config.go b/transport/internet/finalmask/header/dns/config.go new file mode 100644 index 00000000..d5aa5cc3 --- /dev/null +++ b/transport/internet/finalmask/header/dns/config.go @@ -0,0 +1,16 @@ +package dns + +import ( + "net" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, first, leaveSize) +} diff --git a/transport/internet/finalmask/header/dns/config.pb.go b/transport/internet/finalmask/header/dns/config.pb.go new file mode 100644 index 00000000..bcc4ab9b --- /dev/null +++ b/transport/internet/finalmask/header/dns/config.pb.go @@ -0,0 +1,123 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: transport/internet/finalmask/header/dns/config.proto + +package dns + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_header_dns_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_header_dns_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_header_dns_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +var File_transport_internet_finalmask_header_dns_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_header_dns_config_proto_rawDesc = "" + + "\n" + + "4transport/internet/finalmask/header/dns/config.proto\x12,xray.transport.internet.finalmask.header.dns\" \n" + + "\x06Config\x12\x16\n" + + "\x06domain\x18\x01 \x01(\tR\x06domainB\xb5\x01\n" + + "0com.xray.transport.internet.finalmask.header.dnsP\x01ZPgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dns\xaa\x02,Xray.Transport.Internet.Finalmask.Header.Dnsb\x06proto3" + +var ( + file_transport_internet_finalmask_header_dns_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_header_dns_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_header_dns_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_header_dns_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_header_dns_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_dns_config_proto_rawDesc), len(file_transport_internet_finalmask_header_dns_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_header_dns_config_proto_rawDescData +} + +var file_transport_internet_finalmask_header_dns_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_header_dns_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.header.dns.Config +} +var file_transport_internet_finalmask_header_dns_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_header_dns_config_proto_init() } +func file_transport_internet_finalmask_header_dns_config_proto_init() { + if File_transport_internet_finalmask_header_dns_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_dns_config_proto_rawDesc), len(file_transport_internet_finalmask_header_dns_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_header_dns_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_header_dns_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_header_dns_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_header_dns_config_proto = out.File + file_transport_internet_finalmask_header_dns_config_proto_goTypes = nil + file_transport_internet_finalmask_header_dns_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/header/dns/config.proto b/transport/internet/finalmask/header/dns/config.proto new file mode 100644 index 00000000..0ce54e50 --- /dev/null +++ b/transport/internet/finalmask/header/dns/config.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.header.dns; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Dns"; +option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dns"; +option java_package = "com.xray.transport.internet.finalmask.header.dns"; +option java_multiple_files = true; + +message Config { + string domain = 1; +} diff --git a/transport/internet/finalmask/header/dns/conn.go b/transport/internet/finalmask/header/dns/conn.go new file mode 100644 index 00000000..426bef84 --- /dev/null +++ b/transport/internet/finalmask/header/dns/conn.go @@ -0,0 +1,241 @@ +package dns + +import ( + "encoding/binary" + "io" + "net" + sync "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common/dice" + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" +) + +func packDomainName(s string, msg []byte) (off1 int, err error) { + off := 0 + ls := len(s) + // Each dot ends a segment of the name. + // We trade each dot byte for a length byte. + // Except for escaped dots (\.), which are normal dots. + // There is also a trailing zero. + + // Emit sequence of counted strings, chopping at dots. + var ( + begin int + bs []byte + ) + for i := 0; i < ls; i++ { + var c byte + if bs == nil { + c = s[i] + } else { + c = bs[i] + } + + switch c { + case '\\': + if off+1 > len(msg) { + return len(msg), errors.New("buffer size too small") + } + + if bs == nil { + bs = []byte(s) + } + + copy(bs[i:ls-1], bs[i+1:]) + ls-- + case '.': + labelLen := i - begin + if labelLen >= 1<<6 { // top two bits of length must be clear + return len(msg), errors.New("bad rdata") + } + + // off can already (we're in a loop) be bigger than len(msg) + // this happens when a name isn't fully qualified + if off+1+labelLen > len(msg) { + return len(msg), errors.New("buffer size too small") + } + + // The following is covered by the length check above. + msg[off] = byte(labelLen) + + if bs == nil { + copy(msg[off+1:], s[begin:i]) + } else { + copy(msg[off+1:], bs[begin:i]) + } + off += 1 + labelLen + begin = i + 1 + default: + } + } + + if off < len(msg) { + msg[off] = 0 + } + + return off + 1, nil +} + +type dns struct { + header []byte +} + +func (h *dns) Size() int32 { + return int32(len(h.header)) +} + +func (h *dns) Serialize(b []byte) { + copy(b, h.header) + binary.BigEndian.PutUint16(b[0:], dice.RollUint16()) +} + +type dnsConn struct { + first bool + leaveSize int32 + + conn net.PacketConn + header *dns + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + var header []byte + header = binary.BigEndian.AppendUint16(header, 0x0000) // Transaction ID + header = binary.BigEndian.AppendUint16(header, 0x0100) // Flags: Standard query + header = binary.BigEndian.AppendUint16(header, 0x0001) // Questions + header = binary.BigEndian.AppendUint16(header, 0x0000) // Answer RRs + header = binary.BigEndian.AppendUint16(header, 0x0000) // Authority RRs + header = binary.BigEndian.AppendUint16(header, 0x0000) // Additional RRs + buf := make([]byte, 0x100) + off1, err := packDomainName(c.Domain+".", buf) + if err != nil { + return nil, err + } + header = append(header, buf[:off1]...) + header = binary.BigEndian.AppendUint16(header, 0x0001) // Type: A + header = binary.BigEndian.AppendUint16(header, 0x0001) // Class: IN + + conn := &dnsConn{ + first: first, + leaveSize: leaveSize, + + conn: raw, + header: &dns{ + header: header, + }, + } + + if first { + conn.readBuf = make([]byte, 8192) + conn.writeBuf = make([]byte, 8192) + } + + return conn, nil +} + +func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *dnsConn) Size() int32 { + return c.header.Size() +} + +func (c *dnsConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if c.first { + c.readMutex.Lock() + + n, addr, err = c.conn.ReadFrom(c.readBuf) + if err != nil { + c.readMutex.Unlock() + return n, addr, err + } + + if n < int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + if len(p) < n-int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, c.readBuf[c.Size():n]) + + c.readMutex.Unlock() + return n - int(c.Size()), addr, err + } + + n, addr, err = c.conn.ReadFrom(p) + if err != nil { + return n, addr, err + } + + if n < int(c.Size()) { + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, p[c.Size():n]) + + return n - int(c.Size()), addr, err +} + +func (c *dnsConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.first { + if c.leaveSize+c.Size()+int32(len(p)) > 8192 { + return 0, errors.New("too many masks") + } + + c.writeMutex.Lock() + + n = copy(c.writeBuf[c.leaveSize+c.Size():], p) + n += int(c.leaveSize) + int(c.Size()) + + c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) + + nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) + + if err != nil { + c.writeMutex.Unlock() + return 0, err + } + + if nn != n { + c.writeMutex.Unlock() + return 0, errors.New("nn != n") + } + + c.writeMutex.Unlock() + return len(p), nil + } + + c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + + return c.conn.WriteTo(p, addr) +} + +func (c *dnsConn) Close() error { + return c.conn.Close() +} + +func (c *dnsConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *dnsConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *dnsConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *dnsConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/header/dtls/config.go b/transport/internet/finalmask/header/dtls/config.go new file mode 100644 index 00000000..ccce33de --- /dev/null +++ b/transport/internet/finalmask/header/dtls/config.go @@ -0,0 +1,16 @@ +package dtls + +import ( + "net" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, first, leaveSize) +} diff --git a/transport/internet/finalmask/header/dtls/config.pb.go b/transport/internet/finalmask/header/dtls/config.pb.go new file mode 100644 index 00000000..9efcc72d --- /dev/null +++ b/transport/internet/finalmask/header/dtls/config.pb.go @@ -0,0 +1,114 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: transport/internet/finalmask/header/dtls/config.proto + +package dtls + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_header_dtls_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_header_dtls_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_header_dtls_config_proto_rawDescGZIP(), []int{0} +} + +var File_transport_internet_finalmask_header_dtls_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_header_dtls_config_proto_rawDesc = "" + + "\n" + + "5transport/internet/finalmask/header/dtls/config.proto\x12-xray.transport.internet.finalmask.header.dtls\"\b\n" + + "\x06ConfigB\xb8\x01\n" + + "1com.xray.transport.internet.finalmask.header.dtlsP\x01ZQgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dtls\xaa\x02-Xray.Transport.Internet.Finalmask.Header.Dtlsb\x06proto3" + +var ( + file_transport_internet_finalmask_header_dtls_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_header_dtls_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_header_dtls_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_header_dtls_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_header_dtls_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_dtls_config_proto_rawDesc), len(file_transport_internet_finalmask_header_dtls_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_header_dtls_config_proto_rawDescData +} + +var file_transport_internet_finalmask_header_dtls_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_header_dtls_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.header.dtls.Config +} +var file_transport_internet_finalmask_header_dtls_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_header_dtls_config_proto_init() } +func file_transport_internet_finalmask_header_dtls_config_proto_init() { + if File_transport_internet_finalmask_header_dtls_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_dtls_config_proto_rawDesc), len(file_transport_internet_finalmask_header_dtls_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_header_dtls_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_header_dtls_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_header_dtls_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_header_dtls_config_proto = out.File + file_transport_internet_finalmask_header_dtls_config_proto_goTypes = nil + file_transport_internet_finalmask_header_dtls_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/header/dtls/config.proto b/transport/internet/finalmask/header/dtls/config.proto new file mode 100644 index 00000000..28ede2c2 --- /dev/null +++ b/transport/internet/finalmask/header/dtls/config.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.header.dtls; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Dtls"; +option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dtls"; +option java_package = "com.xray.transport.internet.finalmask.header.dtls"; +option java_multiple_files = true; + +message Config {} diff --git a/transport/internet/finalmask/header/dtls/conn.go b/transport/internet/finalmask/header/dtls/conn.go new file mode 100644 index 00000000..f26297e0 --- /dev/null +++ b/transport/internet/finalmask/header/dtls/conn.go @@ -0,0 +1,178 @@ +package dtls + +import ( + "io" + "net" + sync "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common/dice" + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" +) + +type dtls struct { + epoch uint16 + length uint16 + sequence uint32 +} + +func (*dtls) Size() int32 { + return 1 + 2 + 2 + 6 + 2 +} + +func (h *dtls) Serialize(b []byte) { + b[0] = 23 + b[1] = 254 + b[2] = 253 + b[3] = byte(h.epoch >> 8) + b[4] = byte(h.epoch) + b[5] = 0 + b[6] = 0 + b[7] = byte(h.sequence >> 24) + b[8] = byte(h.sequence >> 16) + b[9] = byte(h.sequence >> 8) + b[10] = byte(h.sequence) + h.sequence++ + b[11] = byte(h.length >> 8) + b[12] = byte(h.length) + h.length += 17 + if h.length > 100 { + h.length -= 50 + } +} + +type dtlsConn struct { + first bool + leaveSize int32 + + conn net.PacketConn + header *dtls + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + conn := &dtlsConn{ + first: first, + leaveSize: leaveSize, + + conn: raw, + header: &dtls{ + epoch: dice.RollUint16(), + sequence: 0, + length: 17, + }, + } + + if first { + conn.readBuf = make([]byte, 8192) + conn.writeBuf = make([]byte, 8192) + } + + return conn, nil +} + +func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *dtlsConn) Size() int32 { + return c.header.Size() +} + +func (c *dtlsConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if c.first { + c.readMutex.Lock() + + n, addr, err = c.conn.ReadFrom(c.readBuf) + if err != nil { + c.readMutex.Unlock() + return n, addr, err + } + + if n < int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + if len(p) < n-int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, c.readBuf[c.Size():n]) + + c.readMutex.Unlock() + return n - int(c.Size()), addr, err + } + + n, addr, err = c.conn.ReadFrom(p) + if err != nil { + return n, addr, err + } + + if n < int(c.Size()) { + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, p[c.Size():n]) + + return n - int(c.Size()), addr, err +} + +func (c *dtlsConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.first { + if c.leaveSize+c.Size()+int32(len(p)) > 8192 { + return 0, errors.New("too many masks") + } + + c.writeMutex.Lock() + + n = copy(c.writeBuf[c.leaveSize+c.Size():], p) + n += int(c.leaveSize) + int(c.Size()) + + c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) + + nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) + + if err != nil { + c.writeMutex.Unlock() + return 0, err + } + + if nn != n { + c.writeMutex.Unlock() + return 0, errors.New("nn != n") + } + + c.writeMutex.Unlock() + return len(p), nil + } + + c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + + return c.conn.WriteTo(p, addr) +} + +func (c *dtlsConn) Close() error { + return c.conn.Close() +} + +func (c *dtlsConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *dtlsConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *dtlsConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *dtlsConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/header/srtp/config.go b/transport/internet/finalmask/header/srtp/config.go new file mode 100644 index 00000000..45def616 --- /dev/null +++ b/transport/internet/finalmask/header/srtp/config.go @@ -0,0 +1,16 @@ +package srtp + +import ( + "net" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, first, leaveSize) +} diff --git a/transport/internet/finalmask/header/srtp/config.pb.go b/transport/internet/finalmask/header/srtp/config.pb.go new file mode 100644 index 00000000..17df58b9 --- /dev/null +++ b/transport/internet/finalmask/header/srtp/config.pb.go @@ -0,0 +1,114 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: transport/internet/finalmask/header/srtp/config.proto + +package srtp + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_header_srtp_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_header_srtp_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_header_srtp_config_proto_rawDescGZIP(), []int{0} +} + +var File_transport_internet_finalmask_header_srtp_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_header_srtp_config_proto_rawDesc = "" + + "\n" + + "5transport/internet/finalmask/header/srtp/config.proto\x12-xray.transport.internet.finalmask.header.srtp\"\b\n" + + "\x06ConfigB\xb8\x01\n" + + "1com.xray.transport.internet.finalmask.header.srtpP\x01ZQgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/srtp\xaa\x02-Xray.Transport.Internet.Finalmask.Header.Srtpb\x06proto3" + +var ( + file_transport_internet_finalmask_header_srtp_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_header_srtp_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_header_srtp_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_header_srtp_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_header_srtp_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_srtp_config_proto_rawDesc), len(file_transport_internet_finalmask_header_srtp_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_header_srtp_config_proto_rawDescData +} + +var file_transport_internet_finalmask_header_srtp_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_header_srtp_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.header.srtp.Config +} +var file_transport_internet_finalmask_header_srtp_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_header_srtp_config_proto_init() } +func file_transport_internet_finalmask_header_srtp_config_proto_init() { + if File_transport_internet_finalmask_header_srtp_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_srtp_config_proto_rawDesc), len(file_transport_internet_finalmask_header_srtp_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_header_srtp_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_header_srtp_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_header_srtp_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_header_srtp_config_proto = out.File + file_transport_internet_finalmask_header_srtp_config_proto_goTypes = nil + file_transport_internet_finalmask_header_srtp_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/header/srtp/config.proto b/transport/internet/finalmask/header/srtp/config.proto new file mode 100644 index 00000000..151f28a8 --- /dev/null +++ b/transport/internet/finalmask/header/srtp/config.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.header.srtp; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Srtp"; +option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/srtp"; +option java_package = "com.xray.transport.internet.finalmask.header.srtp"; +option java_multiple_files = true; + +message Config {} diff --git a/transport/internet/finalmask/header/srtp/conn.go b/transport/internet/finalmask/header/srtp/conn.go new file mode 100644 index 00000000..0bc8adc2 --- /dev/null +++ b/transport/internet/finalmask/header/srtp/conn.go @@ -0,0 +1,162 @@ +package srtp + +import ( + "encoding/binary" + "io" + "net" + sync "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common/dice" + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" +) + +type srtp struct { + header uint16 + number uint16 +} + +func (*srtp) Size() int32 { + return 4 +} + +func (h *srtp) Serialize(b []byte) { + h.number++ + binary.BigEndian.PutUint16(b, h.header) + binary.BigEndian.PutUint16(b[2:], h.number) +} + +type srtpConn struct { + first bool + leaveSize int32 + + conn net.PacketConn + header *srtp + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + conn := &srtpConn{ + first: first, + leaveSize: leaveSize, + + conn: raw, + header: &srtp{ + header: 0xB5E8, + number: dice.RollUint16(), + }, + } + + if first { + conn.readBuf = make([]byte, 8192) + conn.writeBuf = make([]byte, 8192) + } + + return conn, nil +} + +func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *srtpConn) Size() int32 { + return c.header.Size() +} + +func (c *srtpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if c.first { + c.readMutex.Lock() + + n, addr, err = c.conn.ReadFrom(c.readBuf) + if err != nil { + c.readMutex.Unlock() + return n, addr, err + } + + if n < int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + if len(p) < n-int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, c.readBuf[c.Size():n]) + + c.readMutex.Unlock() + return n - int(c.Size()), addr, err + } + + n, addr, err = c.conn.ReadFrom(p) + if err != nil { + return n, addr, err + } + + if n < int(c.Size()) { + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, p[c.Size():n]) + + return n - int(c.Size()), addr, err +} + +func (c *srtpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.first { + if c.leaveSize+c.Size()+int32(len(p)) > 8192 { + return 0, errors.New("too many masks") + } + + c.writeMutex.Lock() + + n = copy(c.writeBuf[c.leaveSize+c.Size():], p) + n += int(c.leaveSize) + int(c.Size()) + + c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) + + nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) + + if err != nil { + c.writeMutex.Unlock() + return 0, err + } + + if nn != n { + c.writeMutex.Unlock() + return 0, errors.New("nn != n") + } + + c.writeMutex.Unlock() + return len(p), nil + } + + c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + + return c.conn.WriteTo(p, addr) +} + +func (c *srtpConn) Close() error { + return c.conn.Close() +} + +func (c *srtpConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *srtpConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *srtpConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *srtpConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/header/utp/config.go b/transport/internet/finalmask/header/utp/config.go new file mode 100644 index 00000000..a579d483 --- /dev/null +++ b/transport/internet/finalmask/header/utp/config.go @@ -0,0 +1,16 @@ +package utp + +import ( + "net" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, first, leaveSize) +} diff --git a/transport/internet/finalmask/header/utp/config.pb.go b/transport/internet/finalmask/header/utp/config.pb.go new file mode 100644 index 00000000..b32e5e87 --- /dev/null +++ b/transport/internet/finalmask/header/utp/config.pb.go @@ -0,0 +1,114 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: transport/internet/finalmask/header/utp/config.proto + +package utp + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_header_utp_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_header_utp_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_header_utp_config_proto_rawDescGZIP(), []int{0} +} + +var File_transport_internet_finalmask_header_utp_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_header_utp_config_proto_rawDesc = "" + + "\n" + + "4transport/internet/finalmask/header/utp/config.proto\x12,xray.transport.internet.finalmask.header.utp\"\b\n" + + "\x06ConfigB\xb5\x01\n" + + "0com.xray.transport.internet.finalmask.header.utpP\x01ZPgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/utp\xaa\x02,Xray.Transport.Internet.Finalmask.Header.Utpb\x06proto3" + +var ( + file_transport_internet_finalmask_header_utp_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_header_utp_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_header_utp_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_header_utp_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_header_utp_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_utp_config_proto_rawDesc), len(file_transport_internet_finalmask_header_utp_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_header_utp_config_proto_rawDescData +} + +var file_transport_internet_finalmask_header_utp_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_header_utp_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.header.utp.Config +} +var file_transport_internet_finalmask_header_utp_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_header_utp_config_proto_init() } +func file_transport_internet_finalmask_header_utp_config_proto_init() { + if File_transport_internet_finalmask_header_utp_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_utp_config_proto_rawDesc), len(file_transport_internet_finalmask_header_utp_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_header_utp_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_header_utp_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_header_utp_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_header_utp_config_proto = out.File + file_transport_internet_finalmask_header_utp_config_proto_goTypes = nil + file_transport_internet_finalmask_header_utp_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/header/utp/config.proto b/transport/internet/finalmask/header/utp/config.proto new file mode 100644 index 00000000..ded50f8c --- /dev/null +++ b/transport/internet/finalmask/header/utp/config.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.header.utp; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Utp"; +option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/utp"; +option java_package = "com.xray.transport.internet.finalmask.header.utp"; +option java_multiple_files = true; + +message Config {} diff --git a/transport/internet/finalmask/header/utp/conn.go b/transport/internet/finalmask/header/utp/conn.go new file mode 100644 index 00000000..66939e7b --- /dev/null +++ b/transport/internet/finalmask/header/utp/conn.go @@ -0,0 +1,164 @@ +package utp + +import ( + "encoding/binary" + "io" + "net" + sync "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common/dice" + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" +) + +type utp struct { + header byte + extension byte + connectionID uint16 +} + +func (*utp) Size() int32 { + return 4 +} + +func (h *utp) Serialize(b []byte) { + binary.BigEndian.PutUint16(b, h.connectionID) + b[2] = h.header + b[3] = h.extension +} + +type utpConn struct { + first bool + leaveSize int32 + + conn net.PacketConn + header *utp + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + conn := &utpConn{ + first: first, + leaveSize: leaveSize, + + conn: raw, + header: &utp{ + header: 1, + extension: 0, + connectionID: dice.RollUint16(), + }, + } + + if first { + conn.readBuf = make([]byte, 8192) + conn.writeBuf = make([]byte, 8192) + } + + return conn, nil +} + +func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *utpConn) Size() int32 { + return c.header.Size() +} + +func (c *utpConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if c.first { + c.readMutex.Lock() + + n, addr, err = c.conn.ReadFrom(c.readBuf) + if err != nil { + c.readMutex.Unlock() + return n, addr, err + } + + if n < int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + if len(p) < n-int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, c.readBuf[c.Size():n]) + + c.readMutex.Unlock() + return n - int(c.Size()), addr, err + } + + n, addr, err = c.conn.ReadFrom(p) + if err != nil { + return n, addr, err + } + + if n < int(c.Size()) { + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, p[c.Size():n]) + + return n - int(c.Size()), addr, err +} + +func (c *utpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.first { + if c.leaveSize+c.Size()+int32(len(p)) > 8192 { + return 0, errors.New("too many masks") + } + + c.writeMutex.Lock() + + n = copy(c.writeBuf[c.leaveSize+c.Size():], p) + n += int(c.leaveSize) + int(c.Size()) + + c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) + + nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) + + if err != nil { + c.writeMutex.Unlock() + return 0, err + } + + if nn != n { + c.writeMutex.Unlock() + return 0, errors.New("nn != n") + } + + c.writeMutex.Unlock() + return len(p), nil + } + + c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + + return c.conn.WriteTo(p, addr) +} + +func (c *utpConn) Close() error { + return c.conn.Close() +} + +func (c *utpConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *utpConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *utpConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *utpConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/header/wechat/config.go b/transport/internet/finalmask/header/wechat/config.go new file mode 100644 index 00000000..34971ace --- /dev/null +++ b/transport/internet/finalmask/header/wechat/config.go @@ -0,0 +1,16 @@ +package wechat + +import ( + "net" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, first, leaveSize) +} diff --git a/transport/internet/finalmask/header/wechat/config.pb.go b/transport/internet/finalmask/header/wechat/config.pb.go new file mode 100644 index 00000000..4daa18b2 --- /dev/null +++ b/transport/internet/finalmask/header/wechat/config.pb.go @@ -0,0 +1,114 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: transport/internet/finalmask/header/wechat/config.proto + +package wechat + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_header_wechat_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_header_wechat_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_header_wechat_config_proto_rawDescGZIP(), []int{0} +} + +var File_transport_internet_finalmask_header_wechat_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_header_wechat_config_proto_rawDesc = "" + + "\n" + + "7transport/internet/finalmask/header/wechat/config.proto\x12/xray.transport.internet.finalmask.header.wechat\"\b\n" + + "\x06ConfigB\xbe\x01\n" + + "3com.xray.transport.internet.finalmask.header.wechatP\x01ZSgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wechat\xaa\x02/Xray.Transport.Internet.Finalmask.Header.Wechatb\x06proto3" + +var ( + file_transport_internet_finalmask_header_wechat_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_header_wechat_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_header_wechat_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_header_wechat_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_header_wechat_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_wechat_config_proto_rawDesc), len(file_transport_internet_finalmask_header_wechat_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_header_wechat_config_proto_rawDescData +} + +var file_transport_internet_finalmask_header_wechat_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_header_wechat_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.header.wechat.Config +} +var file_transport_internet_finalmask_header_wechat_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_header_wechat_config_proto_init() } +func file_transport_internet_finalmask_header_wechat_config_proto_init() { + if File_transport_internet_finalmask_header_wechat_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_wechat_config_proto_rawDesc), len(file_transport_internet_finalmask_header_wechat_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_header_wechat_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_header_wechat_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_header_wechat_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_header_wechat_config_proto = out.File + file_transport_internet_finalmask_header_wechat_config_proto_goTypes = nil + file_transport_internet_finalmask_header_wechat_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/header/wechat/config.proto b/transport/internet/finalmask/header/wechat/config.proto new file mode 100644 index 00000000..440c44a5 --- /dev/null +++ b/transport/internet/finalmask/header/wechat/config.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.header.wechat; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Wechat"; +option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wechat"; +option java_package = "com.xray.transport.internet.finalmask.header.wechat"; +option java_multiple_files = true; + +message Config {} diff --git a/transport/internet/finalmask/header/wechat/conn.go b/transport/internet/finalmask/header/wechat/conn.go new file mode 100644 index 00000000..7489475d --- /dev/null +++ b/transport/internet/finalmask/header/wechat/conn.go @@ -0,0 +1,168 @@ +package wechat + +import ( + "encoding/binary" + "io" + "net" + sync "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common/dice" + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" +) + +type wechat struct { + sn uint32 +} + +func (*wechat) Size() int32 { + return 13 +} + +func (h *wechat) Serialize(b []byte) { + h.sn++ + b[0] = 0xa1 + b[1] = 0x08 + binary.BigEndian.PutUint32(b[2:], h.sn) + b[6] = 0x00 + b[7] = 0x10 + b[8] = 0x11 + b[9] = 0x18 + b[10] = 0x30 + b[11] = 0x22 + b[12] = 0x30 +} + +type wechatConn struct { + first bool + leaveSize int32 + + conn net.PacketConn + header *wechat + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + conn := &wechatConn{ + first: first, + leaveSize: leaveSize, + + conn: raw, + header: &wechat{ + sn: uint32(dice.RollUint16()), + }, + } + + if first { + conn.readBuf = make([]byte, 8192) + conn.writeBuf = make([]byte, 8192) + } + + return conn, nil +} + +func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *wechatConn) Size() int32 { + return c.header.Size() +} + +func (c *wechatConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if c.first { + c.readMutex.Lock() + + n, addr, err = c.conn.ReadFrom(c.readBuf) + if err != nil { + c.readMutex.Unlock() + return n, addr, err + } + + if n < int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + if len(p) < n-int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, c.readBuf[c.Size():n]) + + c.readMutex.Unlock() + return n - int(c.Size()), addr, err + } + + n, addr, err = c.conn.ReadFrom(p) + if err != nil { + return n, addr, err + } + + if n < int(c.Size()) { + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, p[c.Size():n]) + + return n - int(c.Size()), addr, err +} + +func (c *wechatConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.first { + if c.leaveSize+c.Size()+int32(len(p)) > 8192 { + return 0, errors.New("too many masks") + } + + c.writeMutex.Lock() + + n = copy(c.writeBuf[c.leaveSize+c.Size():], p) + n += int(c.leaveSize) + int(c.Size()) + + c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) + + nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) + + if err != nil { + c.writeMutex.Unlock() + return 0, err + } + + if nn != n { + c.writeMutex.Unlock() + return 0, errors.New("nn != n") + } + + c.writeMutex.Unlock() + return len(p), nil + } + + c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + + return c.conn.WriteTo(p, addr) +} + +func (c *wechatConn) Close() error { + return c.conn.Close() +} + +func (c *wechatConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *wechatConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *wechatConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *wechatConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/header/wireguard/config.go b/transport/internet/finalmask/header/wireguard/config.go new file mode 100644 index 00000000..5eeee34b --- /dev/null +++ b/transport/internet/finalmask/header/wireguard/config.go @@ -0,0 +1,16 @@ +package wireguard + +import ( + "net" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, first, leaveSize) +} diff --git a/transport/internet/finalmask/header/wireguard/config.pb.go b/transport/internet/finalmask/header/wireguard/config.pb.go new file mode 100644 index 00000000..027d97ca --- /dev/null +++ b/transport/internet/finalmask/header/wireguard/config.pb.go @@ -0,0 +1,114 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: transport/internet/finalmask/header/wireguard/config.proto + +package wireguard + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_header_wireguard_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_header_wireguard_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_header_wireguard_config_proto_rawDescGZIP(), []int{0} +} + +var File_transport_internet_finalmask_header_wireguard_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_header_wireguard_config_proto_rawDesc = "" + + "\n" + + ":transport/internet/finalmask/header/wireguard/config.proto\x122xray.transport.internet.finalmask.header.wireguard\"\b\n" + + "\x06ConfigB\xc7\x01\n" + + "6com.xray.transport.internet.finalmask.header.wireguardP\x01ZVgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wireguard\xaa\x022Xray.Transport.Internet.Finalmask.Header.Wireguardb\x06proto3" + +var ( + file_transport_internet_finalmask_header_wireguard_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_header_wireguard_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_header_wireguard_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_header_wireguard_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_header_wireguard_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_wireguard_config_proto_rawDesc), len(file_transport_internet_finalmask_header_wireguard_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_header_wireguard_config_proto_rawDescData +} + +var file_transport_internet_finalmask_header_wireguard_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_header_wireguard_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.header.wireguard.Config +} +var file_transport_internet_finalmask_header_wireguard_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_header_wireguard_config_proto_init() } +func file_transport_internet_finalmask_header_wireguard_config_proto_init() { + if File_transport_internet_finalmask_header_wireguard_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_header_wireguard_config_proto_rawDesc), len(file_transport_internet_finalmask_header_wireguard_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_header_wireguard_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_header_wireguard_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_header_wireguard_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_header_wireguard_config_proto = out.File + file_transport_internet_finalmask_header_wireguard_config_proto_goTypes = nil + file_transport_internet_finalmask_header_wireguard_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/header/wireguard/config.proto b/transport/internet/finalmask/header/wireguard/config.proto new file mode 100644 index 00000000..01917bcd --- /dev/null +++ b/transport/internet/finalmask/header/wireguard/config.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.header.wireguard; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Header.Wireguard"; +option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wireguard"; +option java_package = "com.xray.transport.internet.finalmask.header.wireguard"; +option java_multiple_files = true; + +message Config {} diff --git a/transport/internet/finalmask/header/wireguard/conn.go b/transport/internet/finalmask/header/wireguard/conn.go new file mode 100644 index 00000000..8b01585e --- /dev/null +++ b/transport/internet/finalmask/header/wireguard/conn.go @@ -0,0 +1,155 @@ +package wireguard + +import ( + "io" + "net" + sync "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" +) + +type wireguare struct{} + +func (*wireguare) Size() int32 { + return 4 +} + +func (h *wireguare) Serialize(b []byte) { + b[0] = 0x04 + b[1] = 0x00 + b[2] = 0x00 + b[3] = 0x00 +} + +type wireguareConn struct { + first bool + leaveSize int32 + + conn net.PacketConn + header *wireguare + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + conn := &wireguareConn{ + first: first, + leaveSize: leaveSize, + + conn: raw, + header: &wireguare{}, + } + + if first { + conn.readBuf = make([]byte, 8192) + conn.writeBuf = make([]byte, 8192) + } + + return conn, nil +} + +func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *wireguareConn) Size() int32 { + return c.header.Size() +} + +func (c *wireguareConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if c.first { + c.readMutex.Lock() + + n, addr, err = c.conn.ReadFrom(c.readBuf) + if err != nil { + c.readMutex.Unlock() + return n, addr, err + } + + if n < int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + if len(p) < n-int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, c.readBuf[c.Size():n]) + + c.readMutex.Unlock() + return n - int(c.Size()), addr, err + } + + n, addr, err = c.conn.ReadFrom(p) + if err != nil { + return n, addr, err + } + + if n < int(c.Size()) { + return 0, addr, errors.New("header").Base(io.ErrShortBuffer) + } + + copy(p, p[c.Size():n]) + + return n - int(c.Size()), addr, err +} + +func (c *wireguareConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.first { + if c.leaveSize+c.Size()+int32(len(p)) > 8192 { + return 0, errors.New("too many masks") + } + + c.writeMutex.Lock() + + n = copy(c.writeBuf[c.leaveSize+c.Size():], p) + n += int(c.leaveSize) + int(c.Size()) + + c.header.Serialize(c.writeBuf[c.leaveSize : c.leaveSize+c.Size()]) + + nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) + + if err != nil { + c.writeMutex.Unlock() + return 0, err + } + + if nn != n { + c.writeMutex.Unlock() + return 0, errors.New("nn != n") + } + + c.writeMutex.Unlock() + return len(p), nil + } + + c.header.Serialize(p[c.leaveSize : c.leaveSize+c.Size()]) + + return c.conn.WriteTo(p, addr) +} + +func (c *wireguareConn) Close() error { + return c.conn.Close() +} + +func (c *wireguareConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *wireguareConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *wireguareConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *wireguareConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/mkcp/aes128gcm/aes128gcm_test.go b/transport/internet/finalmask/mkcp/aes128gcm/aes128gcm_test.go new file mode 100644 index 00000000..ab849c7f --- /dev/null +++ b/transport/internet/finalmask/mkcp/aes128gcm/aes128gcm_test.go @@ -0,0 +1,70 @@ +package aes128gcm_test + +import ( + "crypto/rand" + "crypto/sha256" + "testing" + + "github.com/amnezia-vpn/amnezia-xray-core/common/crypto" + "github.com/stretchr/testify/assert" +) + +func TestAes128GcmSealInPlace(t *testing.T) { + hashedPsk := sha256.Sum256([]byte("psk")) + aead := crypto.NewAesGcm(hashedPsk[:16]) + + text := []byte("0123456789012") + buf := make([]byte, 8192) + + nonceSize := aead.NonceSize() + nonce := buf[:nonceSize] + rand.Read(nonce) + copy(buf[nonceSize:], text) + plaintext := buf[nonceSize : nonceSize+len(text)] + + sealed := aead.Seal(nil, nonce, plaintext, nil) + + _ = aead.Seal(plaintext[:0], nonce, plaintext, nil) + + assert.Equal(t, sealed, buf[nonceSize:nonceSize+aead.Overhead()+len(text)]) +} + +func encrypted(plain []byte) ([]byte, []byte) { + hashedPsk := sha256.Sum256([]byte("psk")) + aead := crypto.NewAesGcm(hashedPsk[:16]) + + nonce := make([]byte, 12) + rand.Read(nonce) + + return nonce, aead.Seal(nil, nonce, plain, nil) +} + +func TestAes128GcmOpenInPlace(t *testing.T) { + a, b := encrypted([]byte("0123456789012")) + buf := make([]byte, 8192) + copy(buf, a) + copy(buf[len(a):], b) + + hashedPsk := sha256.Sum256([]byte("psk")) + aead := crypto.NewAesGcm(hashedPsk[:16]) + + nonceSize := aead.NonceSize() + nonce := buf[:nonceSize] + ciphertext := buf[nonceSize : nonceSize+len(b)] + + opened, _ := aead.Open(nil, nonce, ciphertext, nil) + _, _ = aead.Open(ciphertext[:0], nonce, ciphertext, nil) + + assert.Equal(t, opened, ciphertext[:len(ciphertext)-aead.Overhead()]) +} + +func TestAes128GcmBounce(t *testing.T) { + hashedPsk := sha256.Sum256([]byte("psk")) + aead := crypto.NewAesGcm(hashedPsk[:16]) + buf := make([]byte, aead.NonceSize()+aead.Overhead()) + for i := 0; i < 1000; i++ { + _, _ = rand.Read(buf) + _, err := aead.Open(buf[aead.NonceSize():aead.NonceSize()], buf[:aead.NonceSize()], buf[aead.NonceSize():], nil) + assert.NotEqual(t, err, nil) + } +} diff --git a/transport/internet/finalmask/mkcp/aes128gcm/config.go b/transport/internet/finalmask/mkcp/aes128gcm/config.go new file mode 100644 index 00000000..595dd4ee --- /dev/null +++ b/transport/internet/finalmask/mkcp/aes128gcm/config.go @@ -0,0 +1,16 @@ +package aes128gcm + +import ( + "net" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, first, leaveSize) +} diff --git a/transport/internet/finalmask/mkcp/aes128gcm/config.pb.go b/transport/internet/finalmask/mkcp/aes128gcm/config.pb.go new file mode 100644 index 00000000..2fbb3a40 --- /dev/null +++ b/transport/internet/finalmask/mkcp/aes128gcm/config.pb.go @@ -0,0 +1,123 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: transport/internet/finalmask/mkcp/aes128gcm/config.proto + +package aes128gcm + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + Password string `protobuf:"bytes,1,opt,name=password,proto3" json:"password,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +var File_transport_internet_finalmask_mkcp_aes128gcm_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDesc = "" + + "\n" + + "8transport/internet/finalmask/mkcp/aes128gcm/config.proto\x120xray.transport.internet.finalmask.mkcp.aes128gcm\"$\n" + + "\x06Config\x12\x1a\n" + + "\bpassword\x18\x01 \x01(\tR\bpasswordB\xc1\x01\n" + + "4com.xray.transport.internet.finalmask.mkcp.aes128gcmP\x01ZTgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/aes128gcm\xaa\x020Xray.Transport.Internet.Finalmask.Mkcp.Aes128Gcmb\x06proto3" + +var ( + file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDesc), len(file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDescData +} + +var file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.mkcp.aes128gcm.Config +} +var file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_init() } +func file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_init() { + if File_transport_internet_finalmask_mkcp_aes128gcm_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDesc), len(file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_mkcp_aes128gcm_config_proto = out.File + file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_goTypes = nil + file_transport_internet_finalmask_mkcp_aes128gcm_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/mkcp/aes128gcm/config.proto b/transport/internet/finalmask/mkcp/aes128gcm/config.proto new file mode 100644 index 00000000..f0c9a439 --- /dev/null +++ b/transport/internet/finalmask/mkcp/aes128gcm/config.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.mkcp.aes128gcm; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Mkcp.Aes128Gcm"; +option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/aes128gcm"; +option java_package = "com.xray.transport.internet.finalmask.mkcp.aes128gcm"; +option java_multiple_files = true; + +message Config { + string password = 1; +} diff --git a/transport/internet/finalmask/mkcp/aes128gcm/conn.go b/transport/internet/finalmask/mkcp/aes128gcm/conn.go new file mode 100644 index 00000000..d8b88a20 --- /dev/null +++ b/transport/internet/finalmask/mkcp/aes128gcm/conn.go @@ -0,0 +1,174 @@ +package aes128gcm + +import ( + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "io" + "net" + sync "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common" + "github.com/amnezia-vpn/amnezia-xray-core/common/crypto" + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" +) + +type aes128gcmConn struct { + first bool + leaveSize int32 + + conn net.PacketConn + aead cipher.AEAD + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + hashedPsk := sha256.Sum256([]byte(c.Password)) + + conn := &aes128gcmConn{ + first: first, + leaveSize: leaveSize, + + conn: raw, + aead: crypto.NewAesGcm(hashedPsk[:16]), + } + + if first { + conn.readBuf = make([]byte, 8192) + conn.writeBuf = make([]byte, 8192) + } + + return conn, nil +} + +func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *aes128gcmConn) Size() int32 { + return int32(c.aead.NonceSize()) + int32(c.aead.Overhead()) +} + +func (c *aes128gcmConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if c.first { + c.readMutex.Lock() + + n, addr, err = c.conn.ReadFrom(c.readBuf) + if err != nil { + c.readMutex.Unlock() + return n, addr, err + } + + if n < int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) + } + + if len(p) < n-int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) + } + + nonceSize := c.aead.NonceSize() + nonce := c.readBuf[:nonceSize] + ciphertext := c.readBuf[nonceSize:n] + _, err = c.aead.Open(p[:0], nonce, ciphertext, nil) + if err != nil { + c.readMutex.Unlock() + return 0, addr, errors.New("aead open").Base(err) + } + + c.readMutex.Unlock() + return n - int(c.Size()), addr, nil + } + + n, addr, err = c.conn.ReadFrom(p) + if err != nil { + return n, addr, err + } + + if n < int(c.Size()) { + return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) + } + + nonceSize := c.aead.NonceSize() + nonce := p[:nonceSize] + ciphertext := p[nonceSize:n] + _, err = c.aead.Open(ciphertext[:0], nonce, ciphertext, nil) + if err != nil { + return 0, addr, errors.New("aead open").Base(err) + } + copy(p, p[nonceSize:n-c.aead.Overhead()]) + + return n - int(c.Size()), addr, nil +} + +func (c *aes128gcmConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.first { + if c.leaveSize+c.Size()+int32(len(p)) > 8192 { + return 0, errors.New("too many masks") + } + + c.writeMutex.Lock() + + n = copy(c.writeBuf[c.leaveSize+int32(c.aead.NonceSize()):], p) + // n = copy(c.writeBuf[c.leaveSize+c.Size():], p) + n += int(c.leaveSize) + int(c.Size()) + + nonceSize := c.aead.NonceSize() + nonce := c.writeBuf[c.leaveSize : c.leaveSize+int32(nonceSize)] + common.Must2(rand.Read(nonce)) + // copy(c.writeBuf[c.leaveSize+int32(nonceSize):], c.writeBuf[c.leaveSize+c.Size():n]) + plaintext := c.writeBuf[c.leaveSize+int32(nonceSize) : n-c.aead.Overhead()] + _ = c.aead.Seal(plaintext[:0], nonce, plaintext, nil) + + nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) + + if err != nil { + c.writeMutex.Unlock() + return 0, err + } + + if nn != n { + c.writeMutex.Unlock() + return 0, errors.New("nn != n") + } + + c.writeMutex.Unlock() + return len(p), nil + } + + nonceSize := c.aead.NonceSize() + nonce := p[c.leaveSize : c.leaveSize+int32(nonceSize)] + common.Must2(rand.Read(nonce)) + copy(p[c.leaveSize+int32(nonceSize):], p[c.leaveSize+c.Size():]) + plaintext := p[c.leaveSize+int32(nonceSize) : len(p)-c.aead.Overhead()] + _ = c.aead.Seal(plaintext[:0], nonce, plaintext, nil) + + return c.conn.WriteTo(p, addr) +} + +func (c *aes128gcmConn) Close() error { + return c.conn.Close() +} + +func (c *aes128gcmConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *aes128gcmConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *aes128gcmConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *aes128gcmConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/mkcp/original/config.go b/transport/internet/finalmask/mkcp/original/config.go new file mode 100644 index 00000000..026c979d --- /dev/null +++ b/transport/internet/finalmask/mkcp/original/config.go @@ -0,0 +1,16 @@ +package original + +import ( + "net" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, first, leaveSize) +} diff --git a/transport/internet/finalmask/mkcp/original/config.pb.go b/transport/internet/finalmask/mkcp/original/config.pb.go new file mode 100644 index 00000000..51861b6c --- /dev/null +++ b/transport/internet/finalmask/mkcp/original/config.pb.go @@ -0,0 +1,114 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: transport/internet/finalmask/mkcp/original/config.proto + +package original + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_mkcp_original_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_mkcp_original_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_mkcp_original_config_proto_rawDescGZIP(), []int{0} +} + +var File_transport_internet_finalmask_mkcp_original_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_mkcp_original_config_proto_rawDesc = "" + + "\n" + + "7transport/internet/finalmask/mkcp/original/config.proto\x12/xray.transport.internet.finalmask.mkcp.original\"\b\n" + + "\x06ConfigB\xbe\x01\n" + + "3com.xray.transport.internet.finalmask.mkcp.originalP\x01ZSgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/original\xaa\x02/Xray.Transport.Internet.Finalmask.Mkcp.Originalb\x06proto3" + +var ( + file_transport_internet_finalmask_mkcp_original_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_mkcp_original_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_mkcp_original_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_mkcp_original_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_mkcp_original_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_mkcp_original_config_proto_rawDesc), len(file_transport_internet_finalmask_mkcp_original_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_mkcp_original_config_proto_rawDescData +} + +var file_transport_internet_finalmask_mkcp_original_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_mkcp_original_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.mkcp.original.Config +} +var file_transport_internet_finalmask_mkcp_original_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_mkcp_original_config_proto_init() } +func file_transport_internet_finalmask_mkcp_original_config_proto_init() { + if File_transport_internet_finalmask_mkcp_original_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_mkcp_original_config_proto_rawDesc), len(file_transport_internet_finalmask_mkcp_original_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_mkcp_original_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_mkcp_original_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_mkcp_original_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_mkcp_original_config_proto = out.File + file_transport_internet_finalmask_mkcp_original_config_proto_goTypes = nil + file_transport_internet_finalmask_mkcp_original_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/mkcp/original/config.proto b/transport/internet/finalmask/mkcp/original/config.proto new file mode 100644 index 00000000..7e7e914a --- /dev/null +++ b/transport/internet/finalmask/mkcp/original/config.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.mkcp.original; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Mkcp.Original"; +option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/original"; +option java_package = "com.xray.transport.internet.finalmask.mkcp.original"; +option java_multiple_files = true; + +message Config {} diff --git a/transport/internet/finalmask/mkcp/original/conn.go b/transport/internet/finalmask/mkcp/original/conn.go new file mode 100644 index 00000000..d7fb628b --- /dev/null +++ b/transport/internet/finalmask/mkcp/original/conn.go @@ -0,0 +1,225 @@ +package original + +import ( + "crypto/cipher" + "encoding/binary" + "hash/fnv" + "io" + "net" + sync "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common" + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" +) + +type simple struct{} + +func NewSimple() *simple { + return &simple{} +} + +func (*simple) NonceSize() int { + return 0 +} + +func (*simple) Overhead() int { + return 6 +} + +func (a *simple) Seal(dst, nonce, plain, extra []byte) []byte { + dst = append(dst, 0, 0, 0, 0, 0, 0) + binary.BigEndian.PutUint16(dst[4:], uint16(len(plain))) + dst = append(dst, plain...) + + fnvHash := fnv.New32a() + common.Must2(fnvHash.Write(dst[4:])) + fnvHash.Sum(dst[:0]) + + dstLen := len(dst) + xtra := 4 - dstLen%4 + if xtra != 4 { + dst = append(dst, make([]byte, xtra)...) + } + xorfwd(dst) + if xtra != 4 { + dst = dst[:dstLen] + } + return dst +} + +func (a *simple) Open(dst, nonce, cipherText, extra []byte) ([]byte, error) { + dst = append(dst, cipherText...) + dstLen := len(dst) + xtra := 4 - dstLen%4 + if xtra != 4 { + dst = append(dst, make([]byte, xtra)...) + } + xorbkd(dst) + if xtra != 4 { + dst = dst[:dstLen] + } + + fnvHash := fnv.New32a() + common.Must2(fnvHash.Write(dst[4:])) + if binary.BigEndian.Uint32(dst[:4]) != fnvHash.Sum32() { + return nil, errors.New("invalid auth") + } + + length := binary.BigEndian.Uint16(dst[4:6]) + if len(dst)-6 != int(length) { + return nil, errors.New("invalid auth") + } + + return dst[6:], nil +} + +type simpleConn struct { + first bool + leaveSize int32 + + conn net.PacketConn + aead cipher.AEAD + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + conn := &simpleConn{ + first: first, + leaveSize: leaveSize, + + conn: raw, + aead: &simple{}, + } + + if first { + conn.readBuf = make([]byte, 8192) + conn.writeBuf = make([]byte, 8192) + } + + return conn, nil +} + +func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *simpleConn) Size() int32 { + return int32(c.aead.NonceSize()) + int32(c.aead.Overhead()) +} + +func (c *simpleConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if c.first { + c.readMutex.Lock() + + n, addr, err = c.conn.ReadFrom(c.readBuf) + if err != nil { + c.readMutex.Unlock() + return n, addr, err + } + + if n < int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) + } + + if len(p) < n-int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) + } + + ciphertext := c.readBuf[:n] + opened, err := c.aead.Open(nil, nil, ciphertext, nil) + if err != nil { + c.readMutex.Unlock() + return 0, addr, errors.New("aead open").Base(err) + } + + copy(p, opened) + + c.readMutex.Unlock() + return n - int(c.Size()), addr, nil + } + + n, addr, err = c.conn.ReadFrom(p) + if err != nil { + return n, addr, err + } + + if n < int(c.Size()) { + return 0, addr, errors.New("aead").Base(io.ErrShortBuffer) + } + + ciphertext := p[:n] + opened, err := c.aead.Open(nil, nil, ciphertext, nil) + if err != nil { + c.readMutex.Unlock() + return 0, addr, errors.New("aead open").Base(err) + } + + copy(p, opened) + + return n - int(c.Size()), addr, nil +} + +func (c *simpleConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.first { + if c.leaveSize+c.Size()+int32(len(p)) > 8192 { + return 0, errors.New("too many masks") + } + + c.writeMutex.Lock() + + n = copy(c.writeBuf[c.leaveSize+c.Size():], p) + n += int(c.leaveSize) + int(c.Size()) + + plaintext := c.writeBuf[c.leaveSize+c.Size() : n] + sealed := c.aead.Seal(nil, nil, plaintext, nil) + copy(c.writeBuf[c.leaveSize:], sealed) + + nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) + + if err != nil { + c.writeMutex.Unlock() + return 0, err + } + + if nn != n { + c.writeMutex.Unlock() + return 0, errors.New("nn != n") + } + + c.writeMutex.Unlock() + return len(p), nil + } + + plaintext := p[c.leaveSize+c.Size():] + sealed := c.aead.Seal(nil, nil, plaintext, nil) + copy(p[c.leaveSize:], sealed) + + return c.conn.WriteTo(p, addr) +} + +func (c *simpleConn) Close() error { + return c.conn.Close() +} + +func (c *simpleConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *simpleConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *simpleConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *simpleConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/mkcp/original/simple_test.go b/transport/internet/finalmask/mkcp/original/simple_test.go new file mode 100644 index 00000000..43eb1cb4 --- /dev/null +++ b/transport/internet/finalmask/mkcp/original/simple_test.go @@ -0,0 +1,19 @@ +package original_test + +import ( + "crypto/rand" + "testing" + + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/original" + "github.com/stretchr/testify/assert" +) + +func TestOriginalBounce(t *testing.T) { + aead := original.NewSimple() + buf := make([]byte, aead.NonceSize()+aead.Overhead()) + for i := 0; i < 1000; i++ { + _, _ = rand.Read(buf) + _, err := aead.Open(buf[:0], nil, buf, nil) + assert.NotEqual(t, err, nil) + } +} diff --git a/transport/internet/kcp/xor.go b/transport/internet/finalmask/mkcp/original/xor.go similarity index 95% rename from transport/internet/kcp/xor.go rename to transport/internet/finalmask/mkcp/original/xor.go index 233a2729..b2a06179 100644 --- a/transport/internet/kcp/xor.go +++ b/transport/internet/finalmask/mkcp/original/xor.go @@ -1,7 +1,7 @@ //go:build !amd64 // +build !amd64 -package kcp +package original // xorfwd performs XOR forwards in words, x[i] ^= x[i-4], i from 0 to len func xorfwd(x []byte) { diff --git a/transport/internet/kcp/xor_amd64.go b/transport/internet/finalmask/mkcp/original/xor_amd64.go similarity index 81% rename from transport/internet/kcp/xor_amd64.go rename to transport/internet/finalmask/mkcp/original/xor_amd64.go index 94a4dfc8..7352ace9 100644 --- a/transport/internet/kcp/xor_amd64.go +++ b/transport/internet/finalmask/mkcp/original/xor_amd64.go @@ -1,4 +1,4 @@ -package kcp +package original //go:noescape func xorfwd(x []byte) diff --git a/transport/internet/kcp/xor_amd64.s b/transport/internet/finalmask/mkcp/original/xor_amd64.s similarity index 100% rename from transport/internet/kcp/xor_amd64.s rename to transport/internet/finalmask/mkcp/original/xor_amd64.s diff --git a/transport/internet/finalmask/salamander/config.go b/transport/internet/finalmask/salamander/config.go index e557f31e..c864e270 100644 --- a/transport/internet/finalmask/salamander/config.go +++ b/transport/internet/finalmask/salamander/config.go @@ -2,41 +2,15 @@ package salamander import ( "net" - - "github.com/amnezia-vpn/amnezia-xray-core/common/errors" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/salamander/obfs" ) func (c *Config) UDP() { } -func (c *Config) WrapConnClient(raw net.Conn) (net.Conn, error) { - return raw, nil +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) } -func (c *Config) WrapConnServer(raw net.Conn) (net.Conn, error) { - return raw, nil -} - -func (c *Config) WrapPacketConnClient(raw net.PacketConn) (net.PacketConn, error) { - ob, err := obfs.NewSalamanderObfuscator([]byte(c.Password)) - if err != nil { - return nil, errors.New("salamander err").Base(err) - } - return obfs.WrapPacketConn(raw, ob), nil -} - -func (c *Config) WrapPacketConnServer(raw net.PacketConn) (net.PacketConn, error) { - ob, err := obfs.NewSalamanderObfuscator([]byte(c.Password)) - if err != nil { - return nil, errors.New("salamander err").Base(err) - } - return obfs.WrapPacketConn(raw, ob), nil -} - -func (c *Config) Size() int { - return 0 -} - -func (c *Config) Serialize([]byte) { +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, first, leaveSize) } diff --git a/transport/internet/finalmask/salamander/config.pb.go b/transport/internet/finalmask/salamander/config.pb.go index bc89f020..ffa1ac41 100644 --- a/transport/internet/finalmask/salamander/config.pb.go +++ b/transport/internet/finalmask/salamander/config.pb.go @@ -69,10 +69,10 @@ var File_transport_internet_finalmask_salamander_config_proto protoreflect.FileD const file_transport_internet_finalmask_salamander_config_proto_rawDesc = "" + "\n" + - "4transport/internet/finalmask/salamander/config.proto\x12*xray.transport.internet.udpmask.salamander\"$\n" + + "4transport/internet/finalmask/salamander/config.proto\x12,xray.transport.internet.finalmask.salamander\"$\n" + "\x06Config\x12\x1a\n" + - "\bpassword\x18\x01 \x01(\tR\bpasswordB\xaf\x01\n" + - ".com.xray.transport.internet.udpmask.salamanderP\x01ZNgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/udpmask/salamander\xaa\x02*Xray.Transport.Internet.Udpmask.Salamanderb\x06proto3" + "\bpassword\x18\x01 \x01(\tR\bpasswordB\xb5\x01\n" + + "0com.xray.transport.internet.finalmask.salamanderP\x01ZPgithub.com/amneiza-vpn/amnezia-xray-core/transport/internet/finalmask/salamander\xaa\x02,Xray.Transport.Internet.Finalmask.Salamanderb\x06proto3" var ( file_transport_internet_finalmask_salamander_config_proto_rawDescOnce sync.Once @@ -88,7 +88,7 @@ func file_transport_internet_finalmask_salamander_config_proto_rawDescGZIP() []b var file_transport_internet_finalmask_salamander_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) var file_transport_internet_finalmask_salamander_config_proto_goTypes = []any{ - (*Config)(nil), // 0: xray.transport.internet.udpmask.salamander.Config + (*Config)(nil), // 0: xray.transport.internet.finalmask.salamander.Config } var file_transport_internet_finalmask_salamander_config_proto_depIdxs = []int32{ 0, // [0:0] is the sub-list for method output_type diff --git a/transport/internet/finalmask/salamander/config.proto b/transport/internet/finalmask/salamander/config.proto index c0c46fbc..8e46bb0e 100644 --- a/transport/internet/finalmask/salamander/config.proto +++ b/transport/internet/finalmask/salamander/config.proto @@ -1,9 +1,9 @@ syntax = "proto3"; -package xray.transport.internet.udpmask.salamander; -option csharp_namespace = "Xray.Transport.Internet.Udpmask.Salamander"; -option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/udpmask/salamander"; -option java_package = "com.xray.transport.internet.udpmask.salamander"; +package xray.transport.internet.finalmask.salamander; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Salamander"; +option go_package = "github.com/amneiza-vpn/amnezia-xray-core/transport/internet/finalmask/salamander"; +option java_package = "com.xray.transport.internet.finalmask.salamander"; option java_multiple_files = true; message Config { diff --git a/transport/internet/finalmask/salamander/conn.go b/transport/internet/finalmask/salamander/conn.go new file mode 100644 index 00000000..df11d619 --- /dev/null +++ b/transport/internet/finalmask/salamander/conn.go @@ -0,0 +1,147 @@ +package salamander + +import ( + "io" + "net" + "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" +) + +type obfsPacketConn struct { + first bool + leaveSize int32 + + conn net.PacketConn + obfs *SalamanderObfuscator + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + ob, err := NewSalamanderObfuscator([]byte(c.Password)) + if err != nil { + return nil, errors.New("salamander err").Base(err) + } + + conn := &obfsPacketConn{ + first: first, + leaveSize: leaveSize, + + conn: raw, + obfs: ob, + } + + if first { + conn.readBuf = make([]byte, 8192) + conn.writeBuf = make([]byte, 8192) + } + + return conn, nil +} + +func NewConnServer(c *Config, raw net.PacketConn, first bool, leaveSize int32) (net.PacketConn, error) { + return NewConnClient(c, raw, first, leaveSize) +} + +func (c *obfsPacketConn) Size() int32 { + return smSaltLen +} + +func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if c.first { + c.readMutex.Lock() + + n, addr, err = c.conn.ReadFrom(c.readBuf) + if err != nil { + c.readMutex.Unlock() + return n, addr, err + } + + if n < int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("salamander").Base(io.ErrShortBuffer) + } + + if len(p) < n-int(c.Size()) { + c.readMutex.Unlock() + return 0, addr, errors.New("salamander").Base(io.ErrShortBuffer) + } + + c.obfs.Deobfuscate(c.readBuf[:n], p) + + c.readMutex.Unlock() + return n - int(c.Size()), addr, err + } + + n, addr, err = c.conn.ReadFrom(p) + if err != nil { + return n, addr, err + } + + if n < int(c.Size()) { + return 0, addr, errors.New("salamander").Base(io.ErrShortBuffer) + } + + c.obfs.Deobfuscate(p[:n], p) + + return n - int(c.Size()), addr, err +} + +func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if c.first { + if c.leaveSize+c.Size()+int32(len(p)) > 8192 { + return 0, errors.New("too many masks") + } + + c.writeMutex.Lock() + + n = copy(c.writeBuf[c.leaveSize+c.Size():], p) + n += int(c.leaveSize) + int(c.Size()) + + c.obfs.Obfuscate(c.writeBuf[c.leaveSize+c.Size():n], c.writeBuf[c.leaveSize:n]) + + nn, err := c.conn.WriteTo(c.writeBuf[:n], addr) + + if err != nil { + c.writeMutex.Unlock() + return 0, err + } + + if nn != n { + c.writeMutex.Unlock() + return 0, errors.New("nn != n") + } + + c.writeMutex.Unlock() + return len(p), nil + } + + c.obfs.Obfuscate(p[c.leaveSize+c.Size():], p[c.leaveSize:]) + + return c.conn.WriteTo(p, addr) +} + +func (c *obfsPacketConn) Close() error { + return c.conn.Close() +} + +func (c *obfsPacketConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *obfsPacketConn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *obfsPacketConn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/salamander/obfs/conn.go b/transport/internet/finalmask/salamander/obfs/conn.go deleted file mode 100644 index 6b97592e..00000000 --- a/transport/internet/finalmask/salamander/obfs/conn.go +++ /dev/null @@ -1,121 +0,0 @@ -package obfs - -import ( - "net" - "sync" - "syscall" - "time" -) - -const udpBufferSize = 2048 // QUIC packets are at most 1500 bytes long, so 2k should be more than enough - -// Obfuscator is the interface that wraps the Obfuscate and Deobfuscate methods. -// Both methods return the number of bytes written to out. -// If a packet is not valid, the methods should return 0. -type Obfuscator interface { - Obfuscate(in, out []byte) int - Deobfuscate(in, out []byte) int -} - -var _ net.PacketConn = (*obfsPacketConn)(nil) - -type obfsPacketConn struct { - Conn net.PacketConn - Obfs Obfuscator - - readBuf []byte - readMutex sync.Mutex - writeBuf []byte - writeMutex sync.Mutex -} - -// obfsPacketConnUDP is a special case of obfsPacketConn that uses a UDPConn -// as the underlying connection. We pass additional methods to quic-go to -// enable UDP-specific optimizations. -type obfsPacketConnUDP struct { - *obfsPacketConn - UDPConn *net.UDPConn -} - -// WrapPacketConn enables obfuscation on a net.PacketConn. -// The obfuscation is transparent to the caller - the n bytes returned by -// ReadFrom and WriteTo are the number of original bytes, not after -// obfuscation/deobfuscation. -func WrapPacketConn(conn net.PacketConn, obfs Obfuscator) net.PacketConn { - opc := &obfsPacketConn{ - Conn: conn, - Obfs: obfs, - readBuf: make([]byte, udpBufferSize), - writeBuf: make([]byte, udpBufferSize), - } - if udpConn, ok := conn.(*net.UDPConn); ok { - return &obfsPacketConnUDP{ - obfsPacketConn: opc, - UDPConn: udpConn, - } - } else { - return opc - } -} - -func (c *obfsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - for { - c.readMutex.Lock() - n, addr, err = c.Conn.ReadFrom(c.readBuf) - if n <= 0 { - c.readMutex.Unlock() - return n, addr, err - } - n = c.Obfs.Deobfuscate(c.readBuf[:n], p) - c.readMutex.Unlock() - if n > 0 || err != nil { - return n, addr, err - } - // Invalid packet, try again - } -} - -func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - c.writeMutex.Lock() - nn := c.Obfs.Obfuscate(p, c.writeBuf) - _, err = c.Conn.WriteTo(c.writeBuf[:nn], addr) - c.writeMutex.Unlock() - if err == nil { - n = len(p) - } - return n, err -} - -func (c *obfsPacketConn) Close() error { - return c.Conn.Close() -} - -func (c *obfsPacketConn) LocalAddr() net.Addr { - return c.Conn.LocalAddr() -} - -func (c *obfsPacketConn) SetDeadline(t time.Time) error { - return c.Conn.SetDeadline(t) -} - -func (c *obfsPacketConn) SetReadDeadline(t time.Time) error { - return c.Conn.SetReadDeadline(t) -} - -func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error { - return c.Conn.SetWriteDeadline(t) -} - -// UDP-specific methods below - -func (c *obfsPacketConnUDP) SetReadBuffer(bytes int) error { - return c.UDPConn.SetReadBuffer(bytes) -} - -func (c *obfsPacketConnUDP) SetWriteBuffer(bytes int) error { - return c.UDPConn.SetWriteBuffer(bytes) -} - -func (c *obfsPacketConnUDP) SyscallConn() (syscall.RawConn, error) { - return c.UDPConn.SyscallConn() -} diff --git a/transport/internet/finalmask/salamander/obfs/salamander_test.go b/transport/internet/finalmask/salamander/obfs/salamander_test.go deleted file mode 100644 index 85eafdcc..00000000 --- a/transport/internet/finalmask/salamander/obfs/salamander_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package obfs - -import ( - "crypto/rand" - "testing" - - "github.com/stretchr/testify/assert" -) - -func BenchmarkSalamanderObfuscator_Obfuscate(b *testing.B) { - o, _ := NewSalamanderObfuscator([]byte("average_password")) - in := make([]byte, 1200) - _, _ = rand.Read(in) - out := make([]byte, 2048) - b.ResetTimer() - for i := 0; i < b.N; i++ { - o.Obfuscate(in, out) - } -} - -func BenchmarkSalamanderObfuscator_Deobfuscate(b *testing.B) { - o, _ := NewSalamanderObfuscator([]byte("average_password")) - in := make([]byte, 1200) - _, _ = rand.Read(in) - out := make([]byte, 2048) - b.ResetTimer() - for i := 0; i < b.N; i++ { - o.Deobfuscate(in, out) - } -} - -func TestSalamanderObfuscator(t *testing.T) { - o, _ := NewSalamanderObfuscator([]byte("average_password")) - in := make([]byte, 1200) - oOut := make([]byte, 2048) - dOut := make([]byte, 2048) - for i := 0; i < 1000; i++ { - _, _ = rand.Read(in) - n := o.Obfuscate(in, oOut) - assert.Equal(t, len(in)+smSaltLen, n) - n = o.Deobfuscate(oOut[:n], dOut) - assert.Equal(t, len(in), n) - assert.Equal(t, in, dOut[:n]) - } -} diff --git a/transport/internet/finalmask/salamander/obfs/salamander.go b/transport/internet/finalmask/salamander/salamander.go similarity index 95% rename from transport/internet/finalmask/salamander/obfs/salamander.go rename to transport/internet/finalmask/salamander/salamander.go index 50a3ce26..86d92dcd 100644 --- a/transport/internet/finalmask/salamander/obfs/salamander.go +++ b/transport/internet/finalmask/salamander/salamander.go @@ -1,4 +1,4 @@ -package obfs +package salamander import ( "fmt" @@ -15,8 +15,6 @@ const ( smKeyLen = blake2b.Size256 ) -var _ Obfuscator = (*SalamanderObfuscator)(nil) - var ErrPSKTooShort = fmt.Errorf("PSK must be at least %d bytes", smPSKMinLen) // SalamanderObfuscator is an obfuscator that obfuscates each packet with diff --git a/transport/internet/finalmask/salamander/salamander_test.go b/transport/internet/finalmask/salamander/salamander_test.go new file mode 100644 index 00000000..049a7433 --- /dev/null +++ b/transport/internet/finalmask/salamander/salamander_test.go @@ -0,0 +1,81 @@ +package salamander_test + +import ( + "crypto/rand" + "testing" + + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/salamander" + "github.com/stretchr/testify/assert" +) + +const ( + smSaltLen = 8 +) + +func BenchmarkSalamanderObfuscator_Obfuscate(b *testing.B) { + o, _ := salamander.NewSalamanderObfuscator([]byte("average_password")) + in := make([]byte, 1200) + _, _ = rand.Read(in) + out := make([]byte, 2048) + b.ResetTimer() + for i := 0; i < b.N; i++ { + o.Obfuscate(in, out) + } +} + +func BenchmarkSalamanderObfuscator_Deobfuscate(b *testing.B) { + o, _ := salamander.NewSalamanderObfuscator([]byte("average_password")) + in := make([]byte, 1200) + _, _ = rand.Read(in) + out := make([]byte, 2048) + b.ResetTimer() + for i := 0; i < b.N; i++ { + o.Deobfuscate(in, out) + } +} + +func TestSalamanderObfuscator(t *testing.T) { + o, _ := salamander.NewSalamanderObfuscator([]byte("average_password")) + in := make([]byte, 1200) + oOut := make([]byte, 2048) + dOut := make([]byte, 2048) + for i := 0; i < 1000; i++ { + _, _ = rand.Read(in) + n := o.Obfuscate(in, oOut) + assert.Equal(t, len(in)+smSaltLen, n) + n = o.Deobfuscate(oOut[:n], dOut) + assert.Equal(t, len(in), n) + assert.Equal(t, in, dOut[:n]) + } +} + +func TestSalamanderInPlace(t *testing.T) { + o, _ := salamander.NewSalamanderObfuscator([]byte("average_password")) + + in := make([]byte, 1200) + out := make([]byte, 2048) + _, _ = rand.Read(in) + o.Obfuscate(in, out) + + out2 := make([]byte, 2048) + copy(out2[smSaltLen:], in) + o.Obfuscate(out2[smSaltLen:], out2) + + dOut := make([]byte, 2048) + o.Deobfuscate(out, dOut) + + o.Deobfuscate(out2, out2) + + assert.Equal(t, in, dOut[:1200]) + assert.Equal(t, in, out2[:1200]) +} + +func TestSalamanderBounce(t *testing.T) { + o, _ := salamander.NewSalamanderObfuscator([]byte("average_password")) + buf := make([]byte, 8) + for i := 0; i < 1000; i++ { + _, _ = rand.Read(buf) + n := o.Deobfuscate(buf, buf) + assert.Equal(t, 0, n) + } +} diff --git a/transport/internet/finalmask/udp_test.go b/transport/internet/finalmask/udp_test.go new file mode 100644 index 00000000..57fcbaa0 --- /dev/null +++ b/transport/internet/finalmask/udp_test.go @@ -0,0 +1,129 @@ +package finalmask_test + +import ( + "bytes" + "net" + "testing" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/dns" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/srtp" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/utp" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wechat" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/header/wireguard" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/aes128gcm" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/mkcp/original" + "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/salamander" +) + +func mustSendRecv( + t *testing.T, + from net.PacketConn, + to net.PacketConn, + msg []byte, +) { + t.Helper() + + go func() { + _, err := from.WriteTo(msg, to.LocalAddr()) + if err != nil { + t.Error(err) + } + }() + + buf := make([]byte, 1024) + n, _, err := to.ReadFrom(buf) + if err != nil { + t.Fatal(err) + } + + if n != len(msg) { + t.Fatalf("unexpected size: %d", n) + } + + if !bytes.Equal(buf[:n], msg) { + t.Fatalf("unexpected data") + } +} + +type layerMask struct { + name string + mask finalmask.Udpmask +} + +func TestPacketConnReadWrite(t *testing.T) { + cases := []layerMask{ + { + name: "aes128gcm", + mask: &aes128gcm.Config{Password: "123"}, + }, + { + name: "original", + mask: &original.Config{}, + }, + { + name: "dns", + mask: &dns.Config{Domain: "www.baidu.com"}, + }, + { + name: "srtp", + mask: &srtp.Config{}, + }, + { + name: "utp", + mask: &utp.Config{}, + }, + { + name: "wechat", + mask: &wechat.Config{}, + }, + { + name: "wireguard", + mask: &wireguard.Config{}, + }, + { + name: "salamander", + mask: &salamander.Config{Password: "1234"}, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + mask := c.mask + + maskManager := finalmask.NewUdpmaskManager([]finalmask.Udpmask{mask, mask}) + + client, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer client.Close() + + client, err = maskManager.WrapPacketConnClient(client) + if err != nil { + t.Fatal(err) + } + + server, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer server.Close() + + server, err = maskManager.WrapPacketConnServer(server) + if err != nil { + t.Fatal(err) + } + + _ = client.SetDeadline(time.Now().Add(time.Second)) + _ = server.SetDeadline(time.Now().Add(time.Second)) + + mustSendRecv(t, client, server, []byte("client -> server")) + mustSendRecv(t, server, client, []byte("server -> client")) + + mustSendRecv(t, client, server, []byte{}) + mustSendRecv(t, server, client, []byte{}) + }) + } +} diff --git a/transport/internet/finalmask/xdns/client.go b/transport/internet/finalmask/xdns/client.go new file mode 100644 index 00000000..e6819e1b --- /dev/null +++ b/transport/internet/finalmask/xdns/client.go @@ -0,0 +1,373 @@ +package xdns + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/base32" + "encoding/binary" + "io" + "net" + "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" +) + +const ( + numPadding = 3 + numPaddingForPoll = 8 + initPollDelay = 500 * time.Millisecond + maxPollDelay = 10 * time.Second + pollDelayMultiplier = 2.0 + pollLimit = 16 +) + +var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding) + +type packet struct { + p []byte + addr net.Addr +} + +type xdnsConnClient struct { + conn net.PacketConn + + clientID []byte + domain Name + + pollChan chan struct{} + readQueue chan *packet + writeQueue chan *packet + + closed bool + mutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) { + if !end { + return nil, errors.New("xdns requires being at the outermost level") + } + + domain, err := ParseName(c.Domain) + if err != nil { + return nil, err + } + + conn := &xdnsConnClient{ + conn: raw, + + clientID: make([]byte, 8), + domain: domain, + + pollChan: make(chan struct{}, pollLimit), + readQueue: make(chan *packet, 128), + writeQueue: make(chan *packet, 128), + } + + rand.Read(conn.clientID) + + go conn.recvLoop() + go conn.sendLoop() + + return conn, nil +} + +func (c *xdnsConnClient) recvLoop() { + for { + if c.closed { + break + } + + var buf [4096]byte + + n, addr, err := c.conn.ReadFrom(buf[:]) + if err != nil { + continue + } + + resp, err := MessageFromWireFormat(buf[:n]) + if err != nil { + continue + } + + payload := dnsResponsePayload(&resp, c.domain) + + r := bytes.NewReader(payload) + anyPacket := false + for { + p, err := nextPacket(r) + if err != nil { + break + } + anyPacket = true + + buf := make([]byte, len(p)) + copy(buf, p) + select { + case c.readQueue <- &packet{ + p: buf, + addr: addr, + }: + default: + } + } + + if anyPacket { + select { + case c.pollChan <- struct{}{}: + default: + } + } + } + + close(c.pollChan) + close(c.readQueue) +} + +func (c *xdnsConnClient) sendLoop() { + var addr net.Addr + + pollDelay := initPollDelay + pollTimer := time.NewTimer(pollDelay) + for { + var p *packet + pollTimerExpired := false + + select { + case p = <-c.writeQueue: + default: + select { + case p = <-c.writeQueue: + case <-c.pollChan: + case <-pollTimer.C: + pollTimerExpired = true + } + } + + if p != nil { + addr = p.addr + + select { + case <-c.pollChan: + default: + } + } else if addr != nil { + encoded, _ := encode(nil, c.clientID, c.domain) + p = &packet{ + p: encoded, + addr: addr, + } + } + + if pollTimerExpired { + pollDelay = time.Duration(float64(pollDelay) * pollDelayMultiplier) + if pollDelay > maxPollDelay { + pollDelay = maxPollDelay + } + } else { + if !pollTimer.Stop() { + <-pollTimer.C + } + pollDelay = initPollDelay + } + pollTimer.Reset(pollDelay) + + if c.closed { + return + } + + if p != nil { + _, _ = c.conn.WriteTo(p.p, p.addr) + } + } +} + +func (c *xdnsConnClient) Size() int32 { + return 0 +} + +func (c *xdnsConnClient) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + packet, ok := <-c.readQueue + if !ok { + return 0, nil, io.EOF + } + n = copy(p, packet.p) + if n != len(packet.p) { + return 0, nil, io.ErrShortBuffer + } + return n, packet.addr, nil +} + +func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return 0, errors.New("xdns closed") + } + + encoded, err := encode(p, c.clientID, c.domain) + if err != nil { + errors.LogDebug(context.Background(), "xdns encode err ", err) + return 0, errors.New("xdns encode").Base(err) + } + + select { + case c.writeQueue <- &packet{ + p: encoded, + addr: addr, + }: + return len(p), nil + default: + return 0, errors.New("xdns queue full") + } +} + +func (c *xdnsConnClient) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return nil + } + + c.closed = true + close(c.writeQueue) + + return c.conn.Close() +} + +func (c *xdnsConnClient) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *xdnsConnClient) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *xdnsConnClient) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *xdnsConnClient) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func encode(p []byte, clientID []byte, domain Name) ([]byte, error) { + var decoded []byte + { + if len(p) >= 224 { + return nil, errors.New("too long") + } + var buf bytes.Buffer + buf.Write(clientID[:]) + n := numPadding + if len(p) == 0 { + n = numPaddingForPoll + } + buf.WriteByte(byte(224 + n)) + _, _ = io.CopyN(&buf, rand.Reader, int64(n)) + if len(p) > 0 { + buf.WriteByte(byte(len(p))) + buf.Write(p) + } + decoded = buf.Bytes() + } + + encoded := make([]byte, base32Encoding.EncodedLen(len(decoded))) + base32Encoding.Encode(encoded, decoded) + encoded = bytes.ToLower(encoded) + labels := chunks(encoded, 63) + labels = append(labels, domain...) + name, err := NewName(labels) + if err != nil { + return nil, err + } + + var id uint16 + _ = binary.Read(rand.Reader, binary.BigEndian, &id) + query := &Message{ + ID: id, + Flags: 0x0100, + Question: []Question{ + { + Name: name, + Type: RRTypeTXT, + Class: ClassIN, + }, + }, + Additional: []RR{ + { + Name: Name{}, + Type: RRTypeOPT, + Class: 4096, + TTL: 0, + Data: []byte{}, + }, + }, + } + + buf, err := query.WireFormat() + if err != nil { + return nil, err + } + + return buf, nil +} + +func chunks(p []byte, n int) [][]byte { + var result [][]byte + for len(p) > 0 { + sz := len(p) + if sz > n { + sz = n + } + result = append(result, p[:sz]) + p = p[sz:] + } + return result +} + +func nextPacket(r *bytes.Reader) ([]byte, error) { + var n uint16 + err := binary.Read(r, binary.BigEndian, &n) + if err != nil { + return nil, err + } + p := make([]byte, n) + _, err = io.ReadFull(r, p) + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return p, err +} + +func dnsResponsePayload(resp *Message, domain Name) []byte { + if resp.Flags&0x8000 != 0x8000 { + return nil + } + if resp.Flags&0x000f != RcodeNoError { + return nil + } + + if len(resp.Answer) != 1 { + return nil + } + answer := resp.Answer[0] + + _, ok := answer.Name.TrimSuffix(domain) + if !ok { + return nil + } + + if answer.Type != RRTypeTXT { + return nil + } + payload, err := DecodeRDataTXT(answer.Data) + if err != nil { + return nil + } + + return payload +} diff --git a/transport/internet/finalmask/xdns/config.go b/transport/internet/finalmask/xdns/config.go new file mode 100644 index 00000000..cf30902a --- /dev/null +++ b/transport/internet/finalmask/xdns/config.go @@ -0,0 +1,16 @@ +package xdns + +import ( + "net" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, end) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, end) +} diff --git a/transport/internet/finalmask/xdns/config.pb.go b/transport/internet/finalmask/xdns/config.pb.go new file mode 100644 index 00000000..10eec8d8 --- /dev/null +++ b/transport/internet/finalmask/xdns/config.pb.go @@ -0,0 +1,123 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: transport/internet/finalmask/xdns/config.proto + +package xdns + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_xdns_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_xdns_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_xdns_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +var File_transport_internet_finalmask_xdns_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_xdns_config_proto_rawDesc = "" + + "\n" + + ".transport/internet/finalmask/xdns/config.proto\x12&xray.transport.internet.finalmask.xdns\" \n" + + "\x06Config\x12\x16\n" + + "\x06domain\x18\x01 \x01(\tR\x06domainB\xa3\x01\n" + + "*com.xray.transport.internet.finalmask.xdnsP\x01ZJgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xdns\xaa\x02&Xray.Transport.Internet.Finalmask.Xdnsb\x06proto3" + +var ( + file_transport_internet_finalmask_xdns_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_xdns_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_xdns_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_xdns_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_xdns_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_xdns_config_proto_rawDesc), len(file_transport_internet_finalmask_xdns_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_xdns_config_proto_rawDescData +} + +var file_transport_internet_finalmask_xdns_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_xdns_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.xdns.Config +} +var file_transport_internet_finalmask_xdns_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_xdns_config_proto_init() } +func file_transport_internet_finalmask_xdns_config_proto_init() { + if File_transport_internet_finalmask_xdns_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_xdns_config_proto_rawDesc), len(file_transport_internet_finalmask_xdns_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_xdns_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_xdns_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_xdns_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_xdns_config_proto = out.File + file_transport_internet_finalmask_xdns_config_proto_goTypes = nil + file_transport_internet_finalmask_xdns_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/xdns/config.proto b/transport/internet/finalmask/xdns/config.proto new file mode 100644 index 00000000..2a3a1637 --- /dev/null +++ b/transport/internet/finalmask/xdns/config.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.xdns; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Xdns"; +option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xdns"; +option java_package = "com.xray.transport.internet.finalmask.xdns"; +option java_multiple_files = true; + +message Config { + string domain = 1; +} + diff --git a/transport/internet/finalmask/xdns/dns.go b/transport/internet/finalmask/xdns/dns.go new file mode 100644 index 00000000..4cdac7cd --- /dev/null +++ b/transport/internet/finalmask/xdns/dns.go @@ -0,0 +1,575 @@ +// Package dns deals with encoding and decoding DNS wire format. +package xdns + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "strings" +) + +// The maximum number of DNS name compression pointers we are willing to follow. +// Without something like this, infinite loops are possible. +const compressionPointerLimit = 10 + +var ( + // ErrZeroLengthLabel is the error returned for names that contain a + // zero-length label, like "example..com". + ErrZeroLengthLabel = errors.New("name contains a zero-length label") + + // ErrLabelTooLong is the error returned for labels that are longer than + // 63 octets. + ErrLabelTooLong = errors.New("name contains a label longer than 63 octets") + + // ErrNameTooLong is the error returned for names whose encoded + // representation is longer than 255 octets. + ErrNameTooLong = errors.New("name is longer than 255 octets") + + // ErrReservedLabelType is the error returned when reading a label type + // prefix whose two most significant bits are not 00 or 11. + ErrReservedLabelType = errors.New("reserved label type") + + // ErrTooManyPointers is the error returned when reading a compressed + // name that has too many compression pointers. + ErrTooManyPointers = errors.New("too many compression pointers") + + // ErrTrailingBytes is the error returned when bytes remain in the parse + // buffer after parsing a message. + ErrTrailingBytes = errors.New("trailing bytes after message") + + // ErrIntegerOverflow is the error returned when trying to encode an + // integer greater than 65535 into a 16-bit field. + ErrIntegerOverflow = errors.New("integer overflow") +) + +const ( + // https://tools.ietf.org/html/rfc1035#section-3.2.2 + RRTypeTXT = 16 + // https://tools.ietf.org/html/rfc6891#section-6.1.1 + RRTypeOPT = 41 + + // https://tools.ietf.org/html/rfc1035#section-3.2.4 + ClassIN = 1 + + // https://tools.ietf.org/html/rfc1035#section-4.1.1 + RcodeNoError = 0 // a.k.a. NOERROR + RcodeFormatError = 1 // a.k.a. FORMERR + RcodeNameError = 3 // a.k.a. NXDOMAIN + RcodeNotImplemented = 4 // a.k.a. NOTIMPL + // https://tools.ietf.org/html/rfc6891#section-9 + ExtendedRcodeBadVers = 16 // a.k.a. BADVERS +) + +// Name represents a domain name, a sequence of labels each of which is 63 +// octets or less in length. +// +// https://tools.ietf.org/html/rfc1035#section-3.1 +type Name [][]byte + +// NewName returns a Name from a slice of labels, after checking the labels for +// validity. Does not include a zero-length label at the end of the slice. +func NewName(labels [][]byte) (Name, error) { + name := Name(labels) + // https://tools.ietf.org/html/rfc1035#section-2.3.4 + // Various objects and parameters in the DNS have size limits. + // labels 63 octets or less + // names 255 octets or less + for _, label := range labels { + if len(label) == 0 { + return nil, ErrZeroLengthLabel + } + if len(label) > 63 { + return nil, ErrLabelTooLong + } + } + // Check the total length. + builder := newMessageBuilder() + builder.WriteName(name) + if len(builder.Bytes()) > 255 { + return nil, ErrNameTooLong + } + return name, nil +} + +// ParseName returns a new Name from a string of labels separated by dots, after +// checking the name for validity. A single dot at the end of the string is +// ignored. +func ParseName(s string) (Name, error) { + b := bytes.TrimSuffix([]byte(s), []byte(".")) + if len(b) == 0 { + // bytes.Split(b, ".") would return [""] in this case + return NewName([][]byte{}) + } else { + return NewName(bytes.Split(b, []byte("."))) + } +} + +// String returns a reversible string representation of name. Labels are +// separated by dots, and any bytes in a label that are outside the set +// [0-9A-Za-z-] are replaced with a \xXX hex escape sequence. +func (name Name) String() string { + if len(name) == 0 { + return "." + } + + var buf strings.Builder + for i, label := range name { + if i > 0 { + buf.WriteByte('.') + } + for _, b := range label { + if b == '-' || + ('0' <= b && b <= '9') || + ('A' <= b && b <= 'Z') || + ('a' <= b && b <= 'z') { + buf.WriteByte(b) + } else { + fmt.Fprintf(&buf, "\\x%02x", b) + } + } + } + return buf.String() +} + +// TrimSuffix returns a Name with the given suffix removed, if it was present. +// The second return value indicates whether the suffix was present. If the +// suffix was not present, the first return value is nil. +func (name Name) TrimSuffix(suffix Name) (Name, bool) { + if len(name) < len(suffix) { + return nil, false + } + split := len(name) - len(suffix) + fore, aft := name[:split], name[split:] + for i := 0; i < len(aft); i++ { + if !bytes.Equal(bytes.ToLower(aft[i]), bytes.ToLower(suffix[i])) { + return nil, false + } + } + return fore, true +} + +// Message represents a DNS message. +// +// https://tools.ietf.org/html/rfc1035#section-4.1 +type Message struct { + ID uint16 + Flags uint16 + + Question []Question + Answer []RR + Authority []RR + Additional []RR +} + +// Opcode extracts the OPCODE part of the Flags field. +// +// https://tools.ietf.org/html/rfc1035#section-4.1.1 +func (message *Message) Opcode() uint16 { + return (message.Flags >> 11) & 0xf +} + +// Rcode extracts the RCODE part of the Flags field. +// +// https://tools.ietf.org/html/rfc1035#section-4.1.1 +func (message *Message) Rcode() uint16 { + return message.Flags & 0x000f +} + +// Question represents an entry in the question section of a message. +// +// https://tools.ietf.org/html/rfc1035#section-4.1.2 +type Question struct { + Name Name + Type uint16 + Class uint16 +} + +// RR represents a resource record. +// +// https://tools.ietf.org/html/rfc1035#section-4.1.3 +type RR struct { + Name Name + Type uint16 + Class uint16 + TTL uint32 + Data []byte +} + +// readName parses a DNS name from r. It leaves r positioned just after the +// parsed name. +func readName(r io.ReadSeeker) (Name, error) { + var labels [][]byte + // We limit the number of compression pointers we are willing to follow. + numPointers := 0 + // If we followed any compression pointers, we must finally seek to just + // past the first pointer. + var seekTo int64 +loop: + for { + var labelType byte + err := binary.Read(r, binary.BigEndian, &labelType) + if err != nil { + return nil, err + } + + switch labelType & 0xc0 { + case 0x00: + // This is an ordinary label. + // https://tools.ietf.org/html/rfc1035#section-3.1 + length := int(labelType & 0x3f) + if length == 0 { + break loop + } + label := make([]byte, length) + _, err := io.ReadFull(r, label) + if err != nil { + return nil, err + } + labels = append(labels, label) + case 0xc0: + // This is a compression pointer. + // https://tools.ietf.org/html/rfc1035#section-4.1.4 + upper := labelType & 0x3f + var lower byte + err := binary.Read(r, binary.BigEndian, &lower) + if err != nil { + return nil, err + } + offset := (uint16(upper) << 8) | uint16(lower) + + if numPointers == 0 { + // The first time we encounter a pointer, + // remember our position so we can seek back to + // it when done. + seekTo, err = r.Seek(0, io.SeekCurrent) + if err != nil { + return nil, err + } + } + numPointers++ + if numPointers > compressionPointerLimit { + return nil, ErrTooManyPointers + } + + // Follow the pointer and continue. + _, err = r.Seek(int64(offset), io.SeekStart) + if err != nil { + return nil, err + } + default: + // "The 10 and 01 combinations are reserved for future + // use." + return nil, ErrReservedLabelType + } + } + // If we followed any pointers, then seek back to just after the first + // one. + if numPointers > 0 { + _, err := r.Seek(seekTo, io.SeekStart) + if err != nil { + return nil, err + } + } + return NewName(labels) +} + +// readQuestion parses one entry from the Question section. It leaves r +// positioned just after the parsed entry. +// +// https://tools.ietf.org/html/rfc1035#section-4.1.2 +func readQuestion(r io.ReadSeeker) (Question, error) { + var question Question + var err error + question.Name, err = readName(r) + if err != nil { + return question, err + } + for _, ptr := range []*uint16{&question.Type, &question.Class} { + err := binary.Read(r, binary.BigEndian, ptr) + if err != nil { + return question, err + } + } + + return question, nil +} + +// readRR parses one resource record. It leaves r positioned just after the +// parsed resource record. +// +// https://tools.ietf.org/html/rfc1035#section-4.1.3 +func readRR(r io.ReadSeeker) (RR, error) { + var rr RR + var err error + rr.Name, err = readName(r) + if err != nil { + return rr, err + } + for _, ptr := range []*uint16{&rr.Type, &rr.Class} { + err := binary.Read(r, binary.BigEndian, ptr) + if err != nil { + return rr, err + } + } + err = binary.Read(r, binary.BigEndian, &rr.TTL) + if err != nil { + return rr, err + } + var rdLength uint16 + err = binary.Read(r, binary.BigEndian, &rdLength) + if err != nil { + return rr, err + } + rr.Data = make([]byte, rdLength) + _, err = io.ReadFull(r, rr.Data) + if err != nil { + return rr, err + } + + return rr, nil +} + +// readMessage parses a complete DNS message. It leaves r positioned just after +// the parsed message. +func readMessage(r io.ReadSeeker) (Message, error) { + var message Message + + // Header section + // https://tools.ietf.org/html/rfc1035#section-4.1.1 + var qdCount, anCount, nsCount, arCount uint16 + for _, ptr := range []*uint16{ + &message.ID, &message.Flags, + &qdCount, &anCount, &nsCount, &arCount, + } { + err := binary.Read(r, binary.BigEndian, ptr) + if err != nil { + return message, err + } + } + + // Question section + // https://tools.ietf.org/html/rfc1035#section-4.1.2 + for i := 0; i < int(qdCount); i++ { + question, err := readQuestion(r) + if err != nil { + return message, err + } + message.Question = append(message.Question, question) + } + + // Answer, Authority, and Additional sections + // https://tools.ietf.org/html/rfc1035#section-4.1.3 + for _, rec := range []struct { + ptr *[]RR + count uint16 + }{ + {&message.Answer, anCount}, + {&message.Authority, nsCount}, + {&message.Additional, arCount}, + } { + for i := 0; i < int(rec.count); i++ { + rr, err := readRR(r) + if err != nil { + return message, err + } + *rec.ptr = append(*rec.ptr, rr) + } + } + + return message, nil +} + +// MessageFromWireFormat parses a message from buf and returns a Message object. +// It returns ErrTrailingBytes if there are bytes remaining in buf after parsing +// is done. +func MessageFromWireFormat(buf []byte) (Message, error) { + r := bytes.NewReader(buf) + message, err := readMessage(r) + if err == io.EOF { + err = io.ErrUnexpectedEOF + } else if err == nil { + // Check for trailing bytes. + _, err = r.ReadByte() + if err == io.EOF { + err = nil + } else if err == nil { + err = ErrTrailingBytes + } + } + return message, err +} + +// messageBuilder manages the state of serializing a DNS message. Its main +// function is to keep track of names already written for the purpose of name +// compression. +type messageBuilder struct { + w bytes.Buffer + nameCache map[string]int +} + +// newMessageBuilder creates a new messageBuilder with an empty name cache. +func newMessageBuilder() *messageBuilder { + return &messageBuilder{ + nameCache: make(map[string]int), + } +} + +// Bytes returns the serialized DNS message as a slice of bytes. +func (builder *messageBuilder) Bytes() []byte { + return builder.w.Bytes() +} + +// WriteName appends name to the in-progress messageBuilder, employing +// compression pointers to previously written names if possible. +func (builder *messageBuilder) WriteName(name Name) { + // https://tools.ietf.org/html/rfc1035#section-3.1 + for i := range name { + // Has this suffix already been encoded in the message? + if ptr, ok := builder.nameCache[name[i:].String()]; ok && ptr&0x3fff == ptr { + // If so, we can write a compression pointer. + binary.Write(&builder.w, binary.BigEndian, uint16(0xc000|ptr)) + return + } + // Not cached; we must encode this label verbatim. Store a cache + // entry pointing to the beginning of it. + builder.nameCache[name[i:].String()] = builder.w.Len() + length := len(name[i]) + if length == 0 || length > 63 { + panic(length) + } + builder.w.WriteByte(byte(length)) + builder.w.Write(name[i]) + } + builder.w.WriteByte(0) +} + +// WriteQuestion appends a Question section entry to the in-progress +// messageBuilder. +func (builder *messageBuilder) WriteQuestion(question *Question) { + // https://tools.ietf.org/html/rfc1035#section-4.1.2 + builder.WriteName(question.Name) + binary.Write(&builder.w, binary.BigEndian, question.Type) + binary.Write(&builder.w, binary.BigEndian, question.Class) +} + +// WriteRR appends a resource record to the in-progress messageBuilder. It +// returns ErrIntegerOverflow if the length of rr.Data does not fit in 16 bits. +func (builder *messageBuilder) WriteRR(rr *RR) error { + // https://tools.ietf.org/html/rfc1035#section-4.1.3 + builder.WriteName(rr.Name) + binary.Write(&builder.w, binary.BigEndian, rr.Type) + binary.Write(&builder.w, binary.BigEndian, rr.Class) + binary.Write(&builder.w, binary.BigEndian, rr.TTL) + rdLength := uint16(len(rr.Data)) + if int(rdLength) != len(rr.Data) { + return ErrIntegerOverflow + } + binary.Write(&builder.w, binary.BigEndian, rdLength) + builder.w.Write(rr.Data) + return nil +} + +// WriteMessage appends a complete DNS message to the in-progress +// messageBuilder. It returns ErrIntegerOverflow if the number of entries in any +// section, or the length of the data in any resource record, does not fit in 16 +// bits. +func (builder *messageBuilder) WriteMessage(message *Message) error { + // Header section + // https://tools.ietf.org/html/rfc1035#section-4.1.1 + binary.Write(&builder.w, binary.BigEndian, message.ID) + binary.Write(&builder.w, binary.BigEndian, message.Flags) + for _, count := range []int{ + len(message.Question), + len(message.Answer), + len(message.Authority), + len(message.Additional), + } { + count16 := uint16(count) + if int(count16) != count { + return ErrIntegerOverflow + } + binary.Write(&builder.w, binary.BigEndian, count16) + } + + // Question section + // https://tools.ietf.org/html/rfc1035#section-4.1.2 + for _, question := range message.Question { + builder.WriteQuestion(&question) + } + + // Answer, Authority, and Additional sections + // https://tools.ietf.org/html/rfc1035#section-4.1.3 + for _, rrs := range [][]RR{message.Answer, message.Authority, message.Additional} { + for _, rr := range rrs { + err := builder.WriteRR(&rr) + if err != nil { + return err + } + } + } + + return nil +} + +// WireFormat encodes a Message as a slice of bytes in DNS wire format. It +// returns ErrIntegerOverflow if the number of entries in any section, or the +// length of the data in any resource record, does not fit in 16 bits. +func (message *Message) WireFormat() ([]byte, error) { + builder := newMessageBuilder() + err := builder.WriteMessage(message) + if err != nil { + return nil, err + } + return builder.Bytes(), nil +} + +// DecodeRDataTXT decodes TXT-DATA (as found in the RDATA for a resource record +// with TYPE=TXT) as a raw byte slice, by concatenating all the +// s it contains. +// +// https://tools.ietf.org/html/rfc1035#section-3.3.14 +func DecodeRDataTXT(p []byte) ([]byte, error) { + var buf bytes.Buffer + for { + if len(p) == 0 { + return nil, io.ErrUnexpectedEOF + } + n := int(p[0]) + p = p[1:] + if len(p) < n { + return nil, io.ErrUnexpectedEOF + } + buf.Write(p[:n]) + p = p[n:] + if len(p) == 0 { + break + } + } + return buf.Bytes(), nil +} + +// EncodeRDataTXT encodes a slice of bytes as TXT-DATA, as appropriate for the +// RDATA of a resource record with TYPE=TXT. No length restriction is enforced +// here; that must be checked at a higher level. +// +// https://tools.ietf.org/html/rfc1035#section-3.3.14 +func EncodeRDataTXT(p []byte) []byte { + // https://tools.ietf.org/html/rfc1035#section-3.3 + // https://tools.ietf.org/html/rfc1035#section-3.3.14 + // TXT data is a sequence of one or more s, where + // is a length octet followed by that number of + // octets. + var buf bytes.Buffer + for len(p) > 255 { + buf.WriteByte(255) + buf.Write(p[:255]) + p = p[255:] + } + // Must write here, even if len(p) == 0, because it's "*one or more* + // s". + buf.WriteByte(byte(len(p))) + buf.Write(p) + return buf.Bytes() +} diff --git a/transport/internet/finalmask/xdns/dns_test.go b/transport/internet/finalmask/xdns/dns_test.go new file mode 100644 index 00000000..b07f57b9 --- /dev/null +++ b/transport/internet/finalmask/xdns/dns_test.go @@ -0,0 +1,592 @@ +package xdns + +import ( + "bytes" + "fmt" + "io" + "strconv" + "strings" + "testing" +) + +func namesEqual(a, b Name) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !bytes.Equal(a[i], b[i]) { + return false + } + } + return true +} + +func TestName(t *testing.T) { + for _, test := range []struct { + labels [][]byte + err error + s string + }{ + {[][]byte{}, nil, "."}, + {[][]byte{[]byte("test")}, nil, "test"}, + {[][]byte{[]byte("a"), []byte("b"), []byte("c")}, nil, "a.b.c"}, + + {[][]byte{{}}, ErrZeroLengthLabel, ""}, + {[][]byte{[]byte("a"), {}, []byte("c")}, ErrZeroLengthLabel, ""}, + + // 63 octets. + {[][]byte{[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE")}, nil, + "0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"}, + // 64 octets. + {[][]byte{[]byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDEF")}, ErrLabelTooLong, ""}, + + // 64+64+64+62 octets. + {[][]byte{ + []byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"), + []byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"), + []byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"), + []byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABC"), + }, nil, + "0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE.0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABC"}, + // 64+64+64+63 octets. + {[][]byte{ + []byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"), + []byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"), + []byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCDE"), + []byte("0123456789abcdef0123456789ABCDEF0123456789abcdef0123456789ABCD"), + }, ErrNameTooLong, ""}, + // 127 one-octet labels. + {[][]byte{ + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, + }, nil, + "0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E.F.0.1.2.3.4.5.6.7.8.9.a.b.c.d.e.f.0.1.2.3.4.5.6.7.8.9.A.B.C.D.E"}, + // 128 one-octet labels. + {[][]byte{ + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + {'0'}, {'1'}, {'2'}, {'3'}, {'4'}, {'5'}, {'6'}, {'7'}, {'8'}, {'9'}, {'A'}, {'B'}, {'C'}, {'D'}, {'E'}, {'F'}, + }, ErrNameTooLong, ""}, + } { + // Test that NewName returns proper error codes, and otherwise + // returns an equal slice of labels. + name, err := NewName(test.labels) + if err != test.err || (err == nil && !namesEqual(name, test.labels)) { + t.Errorf("%+q returned (%+q, %v), expected (%+q, %v)", + test.labels, name, err, test.labels, test.err) + continue + } + if test.err != nil { + continue + } + + // Test that the string version of the name comes out as + // expected. + s := name.String() + if s != test.s { + t.Errorf("%+q became string %+q, expected %+q", test.labels, s, test.s) + continue + } + + // Test that parsing from a string back to a Name results in the + // original slice of labels. + name, err = ParseName(s) + if err != nil || !namesEqual(name, test.labels) { + t.Errorf("%+q parsing %+q returned (%+q, %v), expected (%+q, %v)", + test.labels, s, name, err, test.labels, nil) + continue + } + // A trailing dot should be ignored. + if !strings.HasSuffix(s, ".") { + dotName, dotErr := ParseName(s + ".") + if dotErr != err || !namesEqual(dotName, name) { + t.Errorf("%+q parsing %+q returned (%+q, %v), expected (%+q, %v)", + test.labels, s+".", dotName, dotErr, name, err) + continue + } + } + } +} + +func TestParseName(t *testing.T) { + for _, test := range []struct { + s string + name Name + err error + }{ + // This case can't be tested by TestName above because String + // will never produce "" (it produces "." instead). + {"", [][]byte{}, nil}, + } { + name, err := ParseName(test.s) + if err != test.err || (err == nil && !namesEqual(name, test.name)) { + t.Errorf("%+q returned (%+q, %v), expected (%+q, %v)", + test.s, name, err, test.name, test.err) + continue + } + } +} + +func unescapeString(s string) ([][]byte, error) { + if s == "." { + return [][]byte{}, nil + } + + var result [][]byte + for _, label := range strings.Split(s, ".") { + var buf bytes.Buffer + i := 0 + for i < len(label) { + switch label[i] { + case '\\': + if i+3 >= len(label) { + return nil, fmt.Errorf("truncated escape sequence at index %v", i) + } + if label[i+1] != 'x' { + return nil, fmt.Errorf("malformed escape sequence at index %v", i) + } + b, err := strconv.ParseUint(string(label[i+2:i+4]), 16, 8) + if err != nil { + return nil, fmt.Errorf("malformed hex sequence at index %v", i+2) + } + buf.WriteByte(byte(b)) + i += 4 + default: + buf.WriteByte(label[i]) + i++ + } + } + result = append(result, buf.Bytes()) + } + return result, nil +} + +func TestNameString(t *testing.T) { + for _, test := range []struct { + name Name + s string + }{ + {[][]byte{}, "."}, + {[][]byte{[]byte("\x00"), []byte("a.b"), []byte("c\nd\\")}, "\\x00.a\\x2eb.c\\x0ad\\x5c"}, + {[][]byte{ + []byte("\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>"), + []byte("?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}"), + []byte("~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc"), + []byte("\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb"), + []byte("\xfc\xfd\xfe\xff"), + }, "\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\x09\\x0a\\x0b\\x0c\\x0d\\x0e\\x0f\\x10\\x11\\x12\\x13\\x14\\x15\\x16\\x17\\x18\\x19\\x1a\\x1b\\x1c\\x1d\\x1e\\x1f\\x20\\x21\\x22\\x23\\x24\\x25\\x26\\x27\\x28\\x29\\x2a\\x2b\\x2c-\\x2e\\x2f0123456789\\x3a\\x3b\\x3c\\x3d\\x3e.\\x3f\\x40ABCDEFGHIJKLMNOPQRSTUVWXYZ\\x5b\\x5c\\x5d\\x5e\\x5f\\x60abcdefghijklmnopqrstuvwxyz\\x7b\\x7c\\x7d.\\x7e\\x7f\\x80\\x81\\x82\\x83\\x84\\x85\\x86\\x87\\x88\\x89\\x8a\\x8b\\x8c\\x8d\\x8e\\x8f\\x90\\x91\\x92\\x93\\x94\\x95\\x96\\x97\\x98\\x99\\x9a\\x9b\\x9c\\x9d\\x9e\\x9f\\xa0\\xa1\\xa2\\xa3\\xa4\\xa5\\xa6\\xa7\\xa8\\xa9\\xaa\\xab\\xac\\xad\\xae\\xaf\\xb0\\xb1\\xb2\\xb3\\xb4\\xb5\\xb6\\xb7\\xb8\\xb9\\xba\\xbb\\xbc.\\xbd\\xbe\\xbf\\xc0\\xc1\\xc2\\xc3\\xc4\\xc5\\xc6\\xc7\\xc8\\xc9\\xca\\xcb\\xcc\\xcd\\xce\\xcf\\xd0\\xd1\\xd2\\xd3\\xd4\\xd5\\xd6\\xd7\\xd8\\xd9\\xda\\xdb\\xdc\\xdd\\xde\\xdf\\xe0\\xe1\\xe2\\xe3\\xe4\\xe5\\xe6\\xe7\\xe8\\xe9\\xea\\xeb\\xec\\xed\\xee\\xef\\xf0\\xf1\\xf2\\xf3\\xf4\\xf5\\xf6\\xf7\\xf8\\xf9\\xfa\\xfb.\\xfc\\xfd\\xfe\\xff"}, + } { + s := test.name.String() + if s != test.s { + t.Errorf("%+q escaped to %+q, expected %+q", test.name, s, test.s) + continue + } + unescaped, err := unescapeString(s) + if err != nil { + t.Errorf("%+q unescaping %+q resulted in error %v", test.name, s, err) + continue + } + if !namesEqual(Name(unescaped), test.name) { + t.Errorf("%+q roundtripped through %+q to %+q", test.name, s, unescaped) + continue + } + } +} + +func TestNameTrimSuffix(t *testing.T) { + for _, test := range []struct { + name, suffix string + trimmed string + ok bool + }{ + {"", "", ".", true}, + {".", ".", ".", true}, + {"abc", "", "abc", true}, + {"abc", ".", "abc", true}, + {"", "abc", ".", false}, + {".", "abc", ".", false}, + {"example.com", "com", "example", true}, + {"example.com", "net", ".", false}, + {"example.com", "example.com", ".", true}, + {"example.com", "test.com", ".", false}, + {"example.com", "xample.com", ".", false}, + {"example.com", "example", ".", false}, + {"example.com", "COM", "example", true}, + {"EXAMPLE.COM", "com", "EXAMPLE", true}, + } { + tmp, ok := mustParseName(test.name).TrimSuffix(mustParseName(test.suffix)) + trimmed := tmp.String() + if ok != test.ok || trimmed != test.trimmed { + t.Errorf("TrimSuffix %+q %+q returned (%+q, %v), expected (%+q, %v)", + test.name, test.suffix, trimmed, ok, test.trimmed, test.ok) + continue + } + } +} + +func TestReadName(t *testing.T) { + // Good tests. + for _, test := range []struct { + start int64 + end int64 + input string + s string + }{ + // Empty name. + {0, 1, "\x00abcd", "."}, + // No pointers. + {12, 25, "AAAABBBBCCCC\x07example\x03com\x00", "example.com"}, + // Backward pointer. + {25, 31, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x0c", "sub.example.com"}, + // Forward pointer. + {0, 4, "\x01a\xc0\x04\x03bcd\x00", "a.bcd"}, + // Two backwards pointers. + {31, 38, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x0c\x04sub2\xc0\x19", "sub2.sub.example.com"}, + // Forward then backward pointer. + {25, 31, "AAAABBBBCCCC\x07example\x03com\x00\x03sub\xc0\x1f\x04sub2\xc0\x0c", "sub.sub2.example.com"}, + // Overlapping codons. + {0, 4, "\x01a\xc0\x03bcd\x00", "a.bcd"}, + // Pointer to empty label. + {0, 10, "\x07example\xc0\x0a\x00", "example"}, + {1, 11, "\x00\x07example\xc0\x00", "example"}, + // Pointer to pointer to empty label. + {0, 10, "\x07example\xc0\x0a\xc0\x0c\x00", "example"}, + {1, 11, "\x00\x07example\xc0\x0c\xc0\x00", "example"}, + } { + r := bytes.NewReader([]byte(test.input)) + _, err := r.Seek(test.start, io.SeekStart) + if err != nil { + panic(err) + } + name, err := readName(r) + if err != nil { + t.Errorf("%+q returned error %s", test.input, err) + continue + } + s := name.String() + if s != test.s { + t.Errorf("%+q returned %+q, expected %+q", test.input, s, test.s) + continue + } + cur, _ := r.Seek(0, io.SeekCurrent) + if cur != test.end { + t.Errorf("%+q left offset %d, expected %d", test.input, cur, test.end) + continue + } + } + + // Bad tests. + for _, test := range []struct { + start int64 + input string + err error + }{ + {0, "", io.ErrUnexpectedEOF}, + // Reserved label type. + {0, "\x80example", ErrReservedLabelType}, + // Reserved label type. + {0, "\x40example", ErrReservedLabelType}, + // No Terminating empty label. + {0, "\x07example\x03com", io.ErrUnexpectedEOF}, + // Pointer past end of buffer. + {0, "\x07example\xc0\xff", io.ErrUnexpectedEOF}, + // Pointer to self. + {0, "\x07example\x03com\xc0\x0c", ErrTooManyPointers}, + // Pointer to self with intermediate label. + {0, "\x07example\x03com\xc0\x08", ErrTooManyPointers}, + // Two pointers that point to each other. + {0, "\xc0\x02\xc0\x00", ErrTooManyPointers}, + // Two pointers that point to each other, with intermediate labels. + {0, "\x01a\xc0\x04\x01b\xc0\x00", ErrTooManyPointers}, + // EOF while reading label. + {0, "\x0aexample", io.ErrUnexpectedEOF}, + // EOF before second byte of pointer. + {0, "\xc0", io.ErrUnexpectedEOF}, + {0, "\x07example\xc0", io.ErrUnexpectedEOF}, + } { + r := bytes.NewReader([]byte(test.input)) + _, err := r.Seek(test.start, io.SeekStart) + if err != nil { + panic(err) + } + name, err := readName(r) + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + if err != test.err { + t.Errorf("%+q returned (%+q, %v), expected %v", test.input, name, err, test.err) + continue + } + } +} + +func mustParseName(s string) Name { + name, err := ParseName(s) + if err != nil { + panic(err) + } + return name +} + +func questionsEqual(a, b *Question) bool { + if !namesEqual(a.Name, b.Name) { + return false + } + if a.Type != b.Type || a.Class != b.Class { + return false + } + return true +} + +func rrsEqual(a, b *RR) bool { + if !namesEqual(a.Name, b.Name) { + return false + } + if a.Type != b.Type || a.Class != b.Class || a.TTL != b.TTL { + return false + } + if !bytes.Equal(a.Data, b.Data) { + return false + } + return true +} + +func messagesEqual(a, b *Message) bool { + if a.ID != b.ID || a.Flags != b.Flags { + return false + } + if len(a.Question) != len(b.Question) { + return false + } + for i := 0; i < len(a.Question); i++ { + if !questionsEqual(&a.Question[i], &b.Question[i]) { + return false + } + } + for _, rec := range []struct{ rrA, rrB []RR }{ + {a.Answer, b.Answer}, + {a.Authority, b.Authority}, + {a.Additional, b.Additional}, + } { + if len(rec.rrA) != len(rec.rrB) { + return false + } + for i := 0; i < len(rec.rrA); i++ { + if !rrsEqual(&rec.rrA[i], &rec.rrB[i]) { + return false + } + } + } + return true +} + +func TestMessageFromWireFormat(t *testing.T) { + for _, test := range []struct { + buf string + expected Message + err error + }{ + { + "\x12\x34", + Message{}, + io.ErrUnexpectedEOF, + }, + { + "\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01", + Message{ + ID: 0x1234, + Flags: 0x0100, + Question: []Question{ + { + Name: mustParseName("www.example.com"), + Type: 1, + Class: 1, + }, + }, + Answer: []RR{}, + Authority: []RR{}, + Additional: []RR{}, + }, + nil, + }, + { + "\x12\x34\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01X", + Message{}, + ErrTrailingBytes, + }, + { + "\x12\x34\x81\x80\x00\x01\x00\x01\x00\x00\x00\x00\x03www\x07example\x03com\x00\x00\x01\x00\x01\x03www\x07example\x03com\x00\x00\x01\x00\x01\x00\x00\x00\x80\x00\x04\xc0\x00\x02\x01", + Message{ + ID: 0x1234, + Flags: 0x8180, + Question: []Question{ + { + Name: mustParseName("www.example.com"), + Type: 1, + Class: 1, + }, + }, + Answer: []RR{ + { + Name: mustParseName("www.example.com"), + Type: 1, + Class: 1, + TTL: 128, + Data: []byte{192, 0, 2, 1}, + }, + }, + Authority: []RR{}, + Additional: []RR{}, + }, + nil, + }, + } { + message, err := MessageFromWireFormat([]byte(test.buf)) + if err != test.err || (err == nil && !messagesEqual(&message, &test.expected)) { + t.Errorf("%+q\nreturned (%+v, %v)\nexpected (%+v, %v)", + test.buf, message, err, test.expected, test.err) + continue + } + } +} + +func TestMessageWireFormatRoundTrip(t *testing.T) { + for _, message := range []Message{ + { + ID: 0x1234, + Flags: 0x0100, + Question: []Question{ + { + Name: mustParseName("www.example.com"), + Type: 1, + Class: 1, + }, + { + Name: mustParseName("www2.example.com"), + Type: 2, + Class: 2, + }, + }, + Answer: []RR{ + { + Name: mustParseName("abc"), + Type: 2, + Class: 3, + TTL: 0xffffffff, + Data: []byte{1}, + }, + { + Name: mustParseName("xyz"), + Type: 2, + Class: 3, + TTL: 255, + Data: []byte{}, + }, + }, + Authority: []RR{ + { + Name: mustParseName("."), + Type: 65535, + Class: 65535, + TTL: 0, + Data: []byte("XXXXXXXXXXXXXXXXXXX"), + }, + }, + Additional: []RR{}, + }, + } { + buf, err := message.WireFormat() + if err != nil { + t.Errorf("%+v cannot make wire format: %v", message, err) + continue + } + message2, err := MessageFromWireFormat(buf) + if err != nil { + t.Errorf("%+q cannot parse wire format: %v", buf, err) + continue + } + if !messagesEqual(&message, &message2) { + t.Errorf("messages unequal\nbefore: %+v\n after: %+v", message, message2) + continue + } + } +} + +func TestDecodeRDataTXT(t *testing.T) { + for _, test := range []struct { + p []byte + decoded []byte + err error + }{ + {[]byte{}, nil, io.ErrUnexpectedEOF}, + {[]byte("\x00"), []byte{}, nil}, + {[]byte("\x01"), nil, io.ErrUnexpectedEOF}, + } { + decoded, err := DecodeRDataTXT(test.p) + if err != test.err || (err == nil && !bytes.Equal(decoded, test.decoded)) { + t.Errorf("%+q\nreturned (%+q, %v)\nexpected (%+q, %v)", + test.p, decoded, err, test.decoded, test.err) + continue + } + } +} + +func TestEncodeRDataTXT(t *testing.T) { + // Encoding 0 bytes needs to return at least a single length octet of + // zero, not an empty slice. + p := make([]byte, 0) + encoded := EncodeRDataTXT(p) + if len(encoded) < 0 { + t.Errorf("EncodeRDataTXT(%v) returned %v", p, encoded) + } + + // 255 bytes should be able to be encoded into 256 bytes. + p = make([]byte, 255) + encoded = EncodeRDataTXT(p) + if len(encoded) > 256 { + t.Errorf("EncodeRDataTXT(%d bytes) returned %d bytes", len(p), len(encoded)) + } +} + +func TestRDataTXTRoundTrip(t *testing.T) { + for _, p := range [][]byte{ + {}, + []byte("\x00"), + { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, + 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, + 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, + 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, + 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf, + 0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf, + 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf, + 0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf, + 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, + }, + } { + rdata := EncodeRDataTXT(p) + decoded, err := DecodeRDataTXT(rdata) + if err != nil || !bytes.Equal(decoded, p) { + t.Errorf("%+q returned (%+q, %v)", p, decoded, err) + continue + } + } +} diff --git a/transport/internet/finalmask/xdns/server.go b/transport/internet/finalmask/xdns/server.go new file mode 100644 index 00000000..cf2e54f4 --- /dev/null +++ b/transport/internet/finalmask/xdns/server.go @@ -0,0 +1,567 @@ +package xdns + +import ( + "bytes" + "context" + "encoding/binary" + "io" + "net" + "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" +) + +const ( + idleTimeout = 2 * time.Minute + responseTTL = 60 + maxResponseDelay = 1 * time.Second +) + +var ( + maxUDPPayload = 1280 - 40 - 8 + maxEncodedPayload = computeMaxEncodedPayload(maxUDPPayload) +) + +func clientIDToAddr(clientID [8]byte) *net.UDPAddr { + ip := make(net.IP, 16) + + copy(ip, []byte{0xfd, 0x00, 0, 0, 0, 0, 0, 0}) + copy(ip[8:], clientID[:]) + + return &net.UDPAddr{ + IP: ip, + } +} + +type record struct { + Resp *Message + Addr net.Addr + // ClientID [8]byte + ClientAddr net.Addr +} + +type queue struct { + lash time.Time + queue chan []byte + stash chan []byte +} + +type xdnsConnServer struct { + conn net.PacketConn + + domain Name + + ch chan *record + readQueue chan *packet + writeQueueMap map[string]*queue + + closed bool + mutex sync.Mutex +} + +func NewConnServer(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) { + if !end { + return nil, errors.New("xdns requires being at the outermost level") + } + + domain, err := ParseName(c.Domain) + if err != nil { + return nil, err + } + + conn := &xdnsConnServer{ + conn: raw, + + domain: domain, + + ch: make(chan *record, 100), + readQueue: make(chan *packet, 128), + writeQueueMap: make(map[string]*queue), + } + + go conn.clean() + go conn.recvLoop() + go conn.sendLoop() + + return conn, nil +} + +func (c *xdnsConnServer) clean() { + f := func() bool { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return true + } + + now := time.Now() + + for key, q := range c.writeQueueMap { + if now.Sub(q.lash) >= idleTimeout { + close(q.queue) + close(q.stash) + delete(c.writeQueueMap, key) + } + } + + return false + } + + for { + time.Sleep(idleTimeout / 2) + if f() { + return + } + } +} + +func (c *xdnsConnServer) ensureQueue(addr net.Addr) *queue { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return nil + } + + q, ok := c.writeQueueMap[addr.String()] + if !ok { + q = &queue{ + queue: make(chan []byte, 128), + stash: make(chan []byte, 1), + } + c.writeQueueMap[addr.String()] = q + } + q.lash = time.Now() + + return q +} + +func (c *xdnsConnServer) stash(queue *queue, p []byte) { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return + } + + select { + case queue.stash <- p: + default: + } +} + +func (c *xdnsConnServer) recvLoop() { + for { + if c.closed { + break + } + + var buf [4096]byte + n, addr, err := c.conn.ReadFrom(buf[:]) + if err != nil { + continue + } + + query, err := MessageFromWireFormat(buf[:n]) + if err != nil { + continue + } + + resp, payload := responseFor(&query, c.domain) + + var clientID [8]byte + n = copy(clientID[:], payload) + payload = payload[n:] + if n == len(clientID) { + r := bytes.NewReader(payload) + for { + p, err := nextPacketServer(r) + if err != nil { + break + } + + buf := make([]byte, len(p)) + copy(buf, p) + select { + case c.readQueue <- &packet{ + p: buf, + addr: clientIDToAddr(clientID), + }: + default: + } + } + } else { + if resp != nil && resp.Rcode() == RcodeNoError { + resp.Flags |= RcodeNameError + } + } + + if resp != nil { + select { + case c.ch <- &record{resp, addr, clientIDToAddr(clientID)}: + default: + } + } + } + + close(c.ch) + close(c.readQueue) +} + +func (c *xdnsConnServer) sendLoop() { + var nextRec *record + for { + rec := nextRec + nextRec = nil + + if rec == nil { + var ok bool + rec, ok = <-c.ch + if !ok { + break + } + } + + if rec.Resp.Rcode() == RcodeNoError && len(rec.Resp.Question) == 1 { + rec.Resp.Answer = []RR{ + { + Name: rec.Resp.Question[0].Name, + Type: rec.Resp.Question[0].Type, + Class: rec.Resp.Question[0].Class, + TTL: responseTTL, + Data: nil, + }, + } + + var payload bytes.Buffer + limit := maxEncodedPayload + timer := time.NewTimer(maxResponseDelay) + for { + queue := c.ensureQueue(rec.ClientAddr) + if queue == nil { + return + } + + var p []byte + + select { + case p = <-queue.stash: + default: + select { + case p = <-queue.stash: + case p = <-queue.queue: + default: + select { + case p = <-queue.stash: + case p = <-queue.queue: + case <-timer.C: + case nextRec = <-c.ch: + } + } + } + + timer.Reset(0) + + if len(p) == 0 { + break + } + + limit -= 2 + len(p) + if payload.Len() == 0 { + + } else if limit < 0 { + c.stash(queue, p) + + break + } + + if int(uint16(len(p))) != len(p) { + panic(len(p)) + } + + _ = binary.Write(&payload, binary.BigEndian, uint16(len(p))) + payload.Write(p) + } + timer.Stop() + + rec.Resp.Answer[0].Data = EncodeRDataTXT(payload.Bytes()) + } + + buf, err := rec.Resp.WireFormat() + if err != nil { + continue + } + + if len(buf) > maxUDPPayload { + errors.LogDebug(context.Background(), "xdns server truncate ", len(buf)) + buf = buf[:maxUDPPayload] + buf[2] |= 0x02 + } + + if c.closed { + return + } + + _, _ = c.conn.WriteTo(buf, rec.Addr) + } +} + +func (c *xdnsConnServer) Size() int32 { + return 0 +} + +func (c *xdnsConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + packet, ok := <-c.readQueue + if !ok { + return 0, nil, io.EOF + } + n = copy(p, packet.p) + if n != len(packet.p) { + return 0, nil, io.ErrShortBuffer + } + return n, packet.addr, nil +} + +func (c *xdnsConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) { + q := c.ensureQueue(addr) + if q == nil { + return 0, errors.New("xdns closed") + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return 0, errors.New("xdns closed") + } + + buf := make([]byte, len(p)) + copy(buf, p) + + select { + case q.queue <- buf: + return len(p), nil + default: + return 0, errors.New("xdns queue full") + } +} + +func (c *xdnsConnServer) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return nil + } + + c.closed = true + for key, q := range c.writeQueueMap { + close(q.queue) + close(q.stash) + delete(c.writeQueueMap, key) + } + + return c.conn.Close() +} + +func (c *xdnsConnServer) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *xdnsConnServer) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +func (c *xdnsConnServer) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +func (c *xdnsConnServer) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +func nextPacketServer(r *bytes.Reader) ([]byte, error) { + eof := func(err error) error { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return err + } + + for { + prefix, err := r.ReadByte() + if err != nil { + return nil, err + } + if prefix >= 224 { + paddingLen := prefix - 224 + _, err := io.CopyN(io.Discard, r, int64(paddingLen)) + if err != nil { + return nil, eof(err) + } + } else { + p := make([]byte, int(prefix)) + _, err = io.ReadFull(r, p) + return p, eof(err) + } + } +} + +func responseFor(query *Message, domain Name) (*Message, []byte) { + resp := &Message{ + ID: query.ID, + Flags: 0x8000, + Question: query.Question, + } + + if query.Flags&0x8000 != 0 { + return nil, nil + } + + payloadSize := 0 + for _, rr := range query.Additional { + if rr.Type != RRTypeOPT { + continue + } + if len(resp.Additional) != 0 { + resp.Flags |= RcodeFormatError + return resp, nil + } + resp.Additional = append(resp.Additional, RR{ + Name: Name{}, + Type: RRTypeOPT, + Class: 4096, + TTL: 0, + Data: []byte{}, + }) + additional := &resp.Additional[0] + + version := (rr.TTL >> 16) & 0xff + if version != 0 { + resp.Flags |= ExtendedRcodeBadVers & 0xf + additional.TTL = (ExtendedRcodeBadVers >> 4) << 24 + return resp, nil + } + + payloadSize = int(rr.Class) + } + if payloadSize < 512 { + payloadSize = 512 + } + + if len(query.Question) != 1 { + resp.Flags |= RcodeFormatError + return resp, nil + } + question := query.Question[0] + + prefix, ok := question.Name.TrimSuffix(domain) + if !ok { + resp.Flags |= RcodeNameError + return resp, nil + } + resp.Flags |= 0x0400 + + if query.Opcode() != 0 { + resp.Flags |= RcodeNotImplemented + return resp, nil + } + + if question.Type != RRTypeTXT { + resp.Flags |= RcodeNameError + return resp, nil + } + + encoded := bytes.ToUpper(bytes.Join(prefix, nil)) + payload := make([]byte, base32Encoding.DecodedLen(len(encoded))) + n, err := base32Encoding.Decode(payload, encoded) + if err != nil { + resp.Flags |= RcodeNameError + return resp, nil + } + payload = payload[:n] + + if payloadSize < maxUDPPayload { + resp.Flags |= RcodeFormatError + return resp, nil + } + + return resp, payload +} + +func computeMaxEncodedPayload(limit int) int { + maxLengthName, err := NewName([][]byte{ + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), + }) + if err != nil { + panic(err) + } + { + n := 0 + for _, label := range maxLengthName { + n += len(label) + 1 + } + n += 1 + if n != 255 { + panic("computeMaxEncodedPayload n != 255") + } + } + + queryLimit := uint16(limit) + if int(queryLimit) != limit { + queryLimit = 0xffff + } + query := &Message{ + Question: []Question{ + { + Name: maxLengthName, + Type: RRTypeTXT, + Class: RRTypeTXT, + }, + }, + + Additional: []RR{ + { + Name: Name{}, + Type: RRTypeOPT, + Class: queryLimit, + TTL: 0, + Data: []byte{}, + }, + }, + } + resp, _ := responseFor(query, [][]byte{}) + + resp.Answer = []RR{ + { + Name: query.Question[0].Name, + Type: query.Question[0].Type, + Class: query.Question[0].Class, + TTL: responseTTL, + Data: nil, + }, + } + + low := 0 + high := 32768 + for low+1 < high { + mid := (low + high) / 2 + resp.Answer[0].Data = EncodeRDataTXT(make([]byte, mid)) + buf, err := resp.WireFormat() + if err != nil { + panic(err) + } + if len(buf) <= limit { + low = mid + } else { + high = mid + } + } + + return low +} diff --git a/transport/internet/finalmask/xicmp/client.go b/transport/internet/finalmask/xicmp/client.go new file mode 100644 index 00000000..bb8615e2 --- /dev/null +++ b/transport/internet/finalmask/xicmp/client.go @@ -0,0 +1,350 @@ +package xicmp + +import ( + "context" + "io" + "net" + "strings" + "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common/crypto" + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +const ( + initPollDelay = 500 * time.Millisecond + maxPollDelay = 10 * time.Second + pollDelayMultiplier = 2.0 + pollLimit = 16 + windowSize = 1000 +) + +type packet struct { + p []byte + addr net.Addr +} + +type seqStatus struct { + needSeqByte bool + seqByte byte +} + +type xicmpConnClient struct { + conn net.PacketConn + icmpConn *icmp.PacketConn + + typ icmp.Type + id int + seq int + proto int + seqStatus map[int]*seqStatus + + pollChan chan struct{} + readQueue chan *packet + writeQueue chan *packet + + closed bool + mutex sync.Mutex +} + +func NewConnClient(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) { + if !end { + return nil, errors.New("xicmp requires being at the outermost level") + } + + network := "ip4:icmp" + typ := icmp.Type(ipv4.ICMPTypeEcho) + proto := 1 + if strings.Contains(c.Ip, ":") { + network = "ip6:ipv6-icmp" + typ = ipv6.ICMPTypeEchoRequest + proto = 58 + } + + icmpConn, err := icmp.ListenPacket(network, c.Ip) + if err != nil { + return nil, errors.New("xicmp listen err").Base(err) + } + + if c.Id == 0 { + c.Id = int32(crypto.RandBetween(0, 65535)) + } + + conn := &xicmpConnClient{ + conn: raw, + icmpConn: icmpConn, + + typ: typ, + id: int(c.Id), + seq: 1, + proto: proto, + seqStatus: make(map[int]*seqStatus), + + pollChan: make(chan struct{}, pollLimit), + readQueue: make(chan *packet, 128), + writeQueue: make(chan *packet, 128), + } + + go conn.recvLoop() + go conn.sendLoop() + + return conn, nil +} + +func (c *xicmpConnClient) encode(p []byte) ([]byte, error) { + c.mutex.Lock() + defer c.mutex.Unlock() + + needSeqByte := false + var seqByte byte + data := p + if len(p) > 0 { + needSeqByte = true + seqByte = p[0] + } + + msg := icmp.Message{ + Type: c.typ, + Code: 0, + Body: &icmp.Echo{ + ID: c.id, + Seq: c.seq, + Data: data, + }, + } + + buf, err := msg.Marshal(nil) + if err != nil { + return nil, err + } + + if len(buf) > 8192 { + return nil, errors.New("xicmp len(buf) > 8192") + } + + c.seqStatus[c.seq] = &seqStatus{ + needSeqByte: needSeqByte, + seqByte: seqByte, + } + + delete(c.seqStatus, int(uint16(c.seq-windowSize))) + + c.seq++ + + if c.seq == 65536 { + delete(c.seqStatus, int(uint16(c.seq-windowSize))) + c.seq = 1 + } + + return buf, nil +} + +func (c *xicmpConnClient) recvLoop() { + for { + if c.closed { + break + } + + var buf [8192]byte + + n, addr, err := c.icmpConn.ReadFrom(buf[:]) + if err != nil { + continue + } + + msg, err := icmp.ParseMessage(c.proto, buf[:n]) + if err != nil { + continue + } + + if msg.Type != ipv4.ICMPTypeEchoReply && msg.Type != ipv6.ICMPTypeEchoReply { + continue + } + + echo, ok := msg.Body.(*icmp.Echo) + if !ok { + continue + } + + c.mutex.Lock() + seqStatus, ok := c.seqStatus[echo.Seq] + c.mutex.Unlock() + + if !ok { + continue + } + + if seqStatus.needSeqByte { + if len(echo.Data) <= 1 { + continue + } + if echo.Data[0] == seqStatus.seqByte { + continue + } + echo.Data = echo.Data[1:] + } + + if len(echo.Data) > 0 { + c.mutex.Lock() + delete(c.seqStatus, echo.Seq) + c.mutex.Unlock() + + buf := make([]byte, len(echo.Data)) + copy(buf, echo.Data) + select { + case c.readQueue <- &packet{ + p: buf, + addr: &net.UDPAddr{IP: addr.(*net.IPAddr).IP}, + }: + default: + } + + select { + case c.pollChan <- struct{}{}: + default: + } + } + } + + close(c.pollChan) + close(c.readQueue) +} + +func (c *xicmpConnClient) sendLoop() { + var addr net.Addr + + pollDelay := initPollDelay + pollTimer := time.NewTimer(pollDelay) + for { + var p *packet + pollTimerExpired := false + + select { + case p = <-c.writeQueue: + default: + select { + case p = <-c.writeQueue: + case <-c.pollChan: + case <-pollTimer.C: + pollTimerExpired = true + } + } + + if p != nil { + addr = p.addr + + select { + case <-c.pollChan: + default: + } + } else if addr != nil { + encoded, _ := c.encode(nil) + p = &packet{ + p: encoded, + addr: addr, + } + } + + if pollTimerExpired { + pollDelay = time.Duration(float64(pollDelay) * pollDelayMultiplier) + if pollDelay > maxPollDelay { + pollDelay = maxPollDelay + } + } else { + if !pollTimer.Stop() { + <-pollTimer.C + } + pollDelay = initPollDelay + } + pollTimer.Reset(pollDelay) + + if c.closed { + return + } + + if p != nil { + _, err := c.icmpConn.WriteTo(p.p, p.addr) + if err != nil { + errors.LogDebug(context.Background(), "xicmp writeto err ", err) + } + } + } +} + +func (c *xicmpConnClient) Size() int32 { + return 0 +} + +func (c *xicmpConnClient) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + packet, ok := <-c.readQueue + if !ok { + return 0, nil, io.EOF + } + n = copy(p, packet.p) + if n != len(packet.p) { + return 0, nil, io.ErrShortBuffer + } + return n, packet.addr, nil +} + +func (c *xicmpConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { + encoded, err := c.encode(p) + if err != nil { + return 0, errors.New("xicmp encode").Base(err) + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return 0, errors.New("xicmp closed") + } + + select { + case c.writeQueue <- &packet{ + p: encoded, + addr: &net.IPAddr{IP: addr.(*net.UDPAddr).IP}, + }: + return len(p), nil + default: + return 0, errors.New("xicmp queue full") + } +} + +func (c *xicmpConnClient) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return nil + } + + c.closed = true + close(c.writeQueue) + + _ = c.icmpConn.Close() + return c.conn.Close() +} + +func (c *xicmpConnClient) LocalAddr() net.Addr { + return &net.UDPAddr{ + IP: c.icmpConn.LocalAddr().(*net.IPAddr).IP, + Port: c.id, + } +} + +func (c *xicmpConnClient) SetDeadline(t time.Time) error { + return c.icmpConn.SetDeadline(t) +} + +func (c *xicmpConnClient) SetReadDeadline(t time.Time) error { + return c.icmpConn.SetReadDeadline(t) +} + +func (c *xicmpConnClient) SetWriteDeadline(t time.Time) error { + return c.icmpConn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/xicmp/config.go b/transport/internet/finalmask/xicmp/config.go new file mode 100644 index 00000000..81a483af --- /dev/null +++ b/transport/internet/finalmask/xicmp/config.go @@ -0,0 +1,16 @@ +package xicmp + +import ( + "net" +) + +func (c *Config) UDP() { +} + +func (c *Config) WrapPacketConnClient(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnClient(c, raw, end) +} + +func (c *Config) WrapPacketConnServer(raw net.PacketConn, first bool, leaveSize int32, end bool) (net.PacketConn, error) { + return NewConnServer(c, raw, end) +} diff --git a/transport/internet/finalmask/xicmp/config.pb.go b/transport/internet/finalmask/xicmp/config.pb.go new file mode 100644 index 00000000..46bef55a --- /dev/null +++ b/transport/internet/finalmask/xicmp/config.pb.go @@ -0,0 +1,132 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v3.21.12 +// source: transport/internet/finalmask/xicmp/config.proto + +package xicmp + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ip string `protobuf:"bytes,1,opt,name=ip,proto3" json:"ip,omitempty"` + Id int32 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Config) Reset() { + *x = Config{} + mi := &file_transport_internet_finalmask_xicmp_config_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Config) ProtoMessage() {} + +func (x *Config) ProtoReflect() protoreflect.Message { + mi := &file_transport_internet_finalmask_xicmp_config_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Config.ProtoReflect.Descriptor instead. +func (*Config) Descriptor() ([]byte, []int) { + return file_transport_internet_finalmask_xicmp_config_proto_rawDescGZIP(), []int{0} +} + +func (x *Config) GetIp() string { + if x != nil { + return x.Ip + } + return "" +} + +func (x *Config) GetId() int32 { + if x != nil { + return x.Id + } + return 0 +} + +var File_transport_internet_finalmask_xicmp_config_proto protoreflect.FileDescriptor + +const file_transport_internet_finalmask_xicmp_config_proto_rawDesc = "" + + "\n" + + "/transport/internet/finalmask/xicmp/config.proto\x12'xray.transport.internet.finalmask.xicmp\"(\n" + + "\x06Config\x12\x0e\n" + + "\x02ip\x18\x01 \x01(\tR\x02ip\x12\x0e\n" + + "\x02id\x18\x02 \x01(\x05R\x02idB\xa6\x01\n" + + "+com.xray.transport.internet.finalmask.xicmpP\x01ZKgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xicmp\xaa\x02'Xray.Transport.Internet.Finalmask.Xicmpb\x06proto3" + +var ( + file_transport_internet_finalmask_xicmp_config_proto_rawDescOnce sync.Once + file_transport_internet_finalmask_xicmp_config_proto_rawDescData []byte +) + +func file_transport_internet_finalmask_xicmp_config_proto_rawDescGZIP() []byte { + file_transport_internet_finalmask_xicmp_config_proto_rawDescOnce.Do(func() { + file_transport_internet_finalmask_xicmp_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_xicmp_config_proto_rawDesc), len(file_transport_internet_finalmask_xicmp_config_proto_rawDesc))) + }) + return file_transport_internet_finalmask_xicmp_config_proto_rawDescData +} + +var file_transport_internet_finalmask_xicmp_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_transport_internet_finalmask_xicmp_config_proto_goTypes = []any{ + (*Config)(nil), // 0: xray.transport.internet.finalmask.xicmp.Config +} +var file_transport_internet_finalmask_xicmp_config_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_transport_internet_finalmask_xicmp_config_proto_init() } +func file_transport_internet_finalmask_xicmp_config_proto_init() { + if File_transport_internet_finalmask_xicmp_config_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_finalmask_xicmp_config_proto_rawDesc), len(file_transport_internet_finalmask_xicmp_config_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_transport_internet_finalmask_xicmp_config_proto_goTypes, + DependencyIndexes: file_transport_internet_finalmask_xicmp_config_proto_depIdxs, + MessageInfos: file_transport_internet_finalmask_xicmp_config_proto_msgTypes, + }.Build() + File_transport_internet_finalmask_xicmp_config_proto = out.File + file_transport_internet_finalmask_xicmp_config_proto_goTypes = nil + file_transport_internet_finalmask_xicmp_config_proto_depIdxs = nil +} diff --git a/transport/internet/finalmask/xicmp/config.proto b/transport/internet/finalmask/xicmp/config.proto new file mode 100644 index 00000000..e8ca5abb --- /dev/null +++ b/transport/internet/finalmask/xicmp/config.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package xray.transport.internet.finalmask.xicmp; +option csharp_namespace = "Xray.Transport.Internet.Finalmask.Xicmp"; +option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/finalmask/xicmp"; +option java_package = "com.xray.transport.internet.finalmask.xicmp"; +option java_multiple_files = true; + +message Config { + string ip = 1; + int32 id = 2; +} + diff --git a/transport/internet/finalmask/xicmp/server.go b/transport/internet/finalmask/xicmp/server.go new file mode 100644 index 00000000..3ad41503 --- /dev/null +++ b/transport/internet/finalmask/xicmp/server.go @@ -0,0 +1,377 @@ +package xicmp + +import ( + "context" + "io" + "net" + "strings" + "sync" + "time" + + "github.com/amnezia-vpn/amnezia-xray-core/common/crypto" + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +const ( + idleTimeout = 2 * time.Minute + maxResponseDelay = 1 * time.Second +) + +type record struct { + id int + seq int + needSeqByte bool + seqByte byte + addr net.Addr +} + +type queue struct { + lash time.Time + queue chan []byte +} + +type xicmpConnServer struct { + conn net.PacketConn + icmpConn *icmp.PacketConn + + typ icmp.Type + proto int + config *Config + + ch chan *record + readQueue chan *packet + writeQueueMap map[string]*queue + + closed bool + mutex sync.Mutex +} + +func NewConnServer(c *Config, raw net.PacketConn, end bool) (net.PacketConn, error) { + if !end { + return nil, errors.New("xicmp requires being at the outermost level") + } + + network := "ip4:icmp" + typ := icmp.Type(ipv4.ICMPTypeEchoReply) + proto := 1 + if strings.Contains(c.Ip, ":") { + network = "ip6:ipv6-icmp" + typ = ipv6.ICMPTypeEchoReply + proto = 58 + } + + icmpConn, err := icmp.ListenPacket(network, c.Ip) + if err != nil { + return nil, errors.New("xicmp listen err").Base(err) + } + + conn := &xicmpConnServer{ + conn: raw, + icmpConn: icmpConn, + + typ: typ, + proto: proto, + config: c, + + ch: make(chan *record, 100), + readQueue: make(chan *packet, 128), + writeQueueMap: make(map[string]*queue), + } + + go conn.clean() + go conn.recvLoop() + go conn.sendLoop() + + return conn, nil +} + +func (c *xicmpConnServer) clean() { + f := func() bool { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return true + } + + now := time.Now() + + for key, q := range c.writeQueueMap { + if now.Sub(q.lash) >= idleTimeout { + close(q.queue) + delete(c.writeQueueMap, key) + } + } + + return false + } + + for { + time.Sleep(idleTimeout / 2) + if f() { + return + } + } +} + +func (c *xicmpConnServer) ensureQueue(addr net.Addr) *queue { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return nil + } + + q, ok := c.writeQueueMap[addr.String()] + if !ok { + q = &queue{ + queue: make(chan []byte, 128), + } + c.writeQueueMap[addr.String()] = q + } + q.lash = time.Now() + + return q +} + +func (c *xicmpConnServer) encode(p []byte, id int, seq int, needSeqByte bool, seqByte byte) ([]byte, error) { + data := p + if needSeqByte { + b2 := c.randUntil(seqByte) + data = append([]byte{b2}, p...) + } + + msg := icmp.Message{ + Type: c.typ, + Code: 0, + Body: &icmp.Echo{ + ID: id, + Seq: seq, + Data: data, + }, + } + + buf, err := msg.Marshal(nil) + if err != nil { + return nil, err + } + + if len(buf) > 8192 { + return nil, errors.New("xicmp len(buf) > 8192") + } + + return buf, nil +} + +func (c *xicmpConnServer) randUntil(b1 byte) byte { + b2 := byte(crypto.RandBetween(0, 255)) + for { + if b2 != b1 { + return b2 + } + b2 = byte(crypto.RandBetween(0, 255)) + } +} + +func (c *xicmpConnServer) recvLoop() { + for { + if c.closed { + break + } + + var buf [8192]byte + + n, addr, err := c.icmpConn.ReadFrom(buf[:]) + if err != nil { + continue + } + + msg, err := icmp.ParseMessage(c.proto, buf[:n]) + if err != nil { + continue + } + + if msg.Type != ipv4.ICMPTypeEcho && msg.Type != ipv6.ICMPTypeEchoRequest { + continue + } + + echo, ok := msg.Body.(*icmp.Echo) + if !ok { + continue + } + + if c.config.Id != 0 && echo.ID != int(c.config.Id) { + continue + } + + needSeqByte := false + var seqByte byte + + if len(echo.Data) > 0 { + needSeqByte = true + seqByte = echo.Data[0] + + buf := make([]byte, len(echo.Data)) + copy(buf, echo.Data) + select { + case c.readQueue <- &packet{ + p: buf, + addr: &net.UDPAddr{ + IP: addr.(*net.IPAddr).IP, + Port: echo.ID, + }, + }: + default: + } + } + + select { + case c.ch <- &record{ + id: echo.ID, + seq: echo.Seq, + needSeqByte: needSeqByte, + seqByte: seqByte, + addr: &net.UDPAddr{ + IP: addr.(*net.IPAddr).IP, + Port: echo.ID, + }, + }: + default: + } + } + + close(c.ch) + close(c.readQueue) +} + +func (c *xicmpConnServer) sendLoop() { + var nextRec *record + for { + rec := nextRec + nextRec = nil + + if rec == nil { + var ok bool + rec, ok = <-c.ch + if !ok { + break + } + } + + queue := c.ensureQueue(rec.addr) + if queue == nil { + return + } + + var p []byte + + timer := time.NewTimer(maxResponseDelay) + + select { + case p = <-queue.queue: + default: + select { + case p = <-queue.queue: + case <-timer.C: + case nextRec = <-c.ch: + } + } + + timer.Stop() + + if len(p) == 0 { + continue + } + + buf, err := c.encode(p, rec.id, rec.seq, rec.needSeqByte, rec.seqByte) + if err != nil { + continue + } + + if c.closed { + return + } + + _, err = c.icmpConn.WriteTo(buf, &net.IPAddr{IP: rec.addr.(*net.UDPAddr).IP}) + if err != nil { + errors.LogDebug(context.Background(), "xicmp writeto err ", err) + } + } +} + +func (c *xicmpConnServer) Size() int32 { + return 0 +} + +func (c *xicmpConnServer) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + packet, ok := <-c.readQueue + if !ok { + return 0, nil, io.EOF + } + n = copy(p, packet.p) + if n != len(packet.p) { + return 0, nil, io.ErrShortBuffer + } + return n, packet.addr, nil +} + +func (c *xicmpConnServer) WriteTo(p []byte, addr net.Addr) (n int, err error) { + q := c.ensureQueue(addr) + if q == nil { + return 0, errors.New("xicmp closed") + } + + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return 0, errors.New("xicmp closed") + } + + buf := make([]byte, len(p)) + copy(buf, p) + + select { + case q.queue <- buf: + return len(p), nil + default: + return 0, errors.New("xicmp queue full") + } +} + +func (c *xicmpConnServer) Close() error { + c.mutex.Lock() + defer c.mutex.Unlock() + + if c.closed { + return nil + } + + c.closed = true + for key, q := range c.writeQueueMap { + close(q.queue) + delete(c.writeQueueMap, key) + } + + _ = c.icmpConn.Close() + return c.conn.Close() +} + +func (c *xicmpConnServer) LocalAddr() net.Addr { + return &net.UDPAddr{IP: c.icmpConn.LocalAddr().(*net.IPAddr).IP} +} + +func (c *xicmpConnServer) SetDeadline(t time.Time) error { + return c.icmpConn.SetDeadline(t) +} + +func (c *xicmpConnServer) SetReadDeadline(t time.Time) error { + return c.icmpConn.SetReadDeadline(t) +} + +func (c *xicmpConnServer) SetWriteDeadline(t time.Time) error { + return c.icmpConn.SetWriteDeadline(t) +} diff --git a/transport/internet/finalmask/xicmp/xicmp_test.go b/transport/internet/finalmask/xicmp/xicmp_test.go new file mode 100644 index 00000000..1ac92181 --- /dev/null +++ b/transport/internet/finalmask/xicmp/xicmp_test.go @@ -0,0 +1,74 @@ +package xicmp_test + +import ( + "bytes" + "fmt" + "testing" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +func TestICMPEchoMarshal(t *testing.T) { + msg := icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: 65535, + Seq: 65537, + Data: nil, + }, + } + ICMPTypeEcho, _ := msg.Marshal(nil) + fmt.Println("ICMPTypeEcho", len(ICMPTypeEcho), ICMPTypeEcho) + + msg = icmp.Message{ + Type: ipv4.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 65535, + Seq: 65537, + Data: nil, + }, + } + ICMPTypeEchoReply, _ := msg.Marshal(nil) + fmt.Println("ICMPTypeEchoReply", len(ICMPTypeEchoReply), ICMPTypeEchoReply) + + msg = icmp.Message{ + Type: ipv6.ICMPTypeEchoRequest, + Code: 0, + Body: &icmp.Echo{ + ID: 65535, + Seq: 65537, + Data: nil, + }, + } + ICMPTypeEchoRequest, _ := msg.Marshal(nil) + fmt.Println("ICMPTypeEchoRequest", len(ICMPTypeEchoRequest), ICMPTypeEchoRequest) + + msg = icmp.Message{ + Type: ipv6.ICMPTypeEchoReply, + Code: 0, + Body: &icmp.Echo{ + ID: 65535, + Seq: 65537, + Data: nil, + }, + } + V6ICMPTypeEchoReply, _ := msg.Marshal(nil) + fmt.Println("V6ICMPTypeEchoReply", len(V6ICMPTypeEchoReply), V6ICMPTypeEchoReply) + + if !bytes.Equal(ICMPTypeEcho[0:2], []byte{8, 0}) || !bytes.Equal(ICMPTypeEcho[4:], []byte{255, 255, 0, 1}) { + t.Fatalf("ICMPTypeEcho Type/Code or ID/Seq mismatch: %v", ICMPTypeEcho) + } + if !bytes.Equal(ICMPTypeEchoReply[0:2], []byte{0, 0}) || !bytes.Equal(ICMPTypeEchoReply[4:], []byte{255, 255, 0, 1}) { + t.Fatalf("ICMPTypeEchoReply Type/Code or ID/Seq mismatch: %v", ICMPTypeEchoReply) + } + if !bytes.Equal(ICMPTypeEchoRequest[0:2], []byte{128, 0}) || !bytes.Equal(ICMPTypeEchoRequest[4:], []byte{255, 255, 0, 1}) { + t.Fatalf("ICMPTypeEchoRequest Type/Code or ID/Seq mismatch: %v", ICMPTypeEchoRequest) + } + if !bytes.Equal(V6ICMPTypeEchoReply[0:2], []byte{129, 0}) || !bytes.Equal(V6ICMPTypeEchoReply[4:], []byte{255, 255, 0, 1}) { + t.Fatalf("V6ICMPTypeEchoReply Type/Code or ID/Seq mismatch: %v", V6ICMPTypeEchoReply) + } +} diff --git a/transport/internet/grpc/dial.go b/transport/internet/grpc/dial.go index 5c532155..ad4b3830 100644 --- a/transport/internet/grpc/dial.go +++ b/transport/internet/grpc/dial.go @@ -10,6 +10,7 @@ import ( "github.com/amnezia-vpn/amnezia-xray-core/common/errors" "github.com/amnezia-vpn/amnezia-xray-core/common/net" "github.com/amnezia-vpn/amnezia-xray-core/common/session" + "github.com/amnezia-vpn/amnezia-xray-core/common/utils" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/grpc/encoding" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/reality" @@ -167,9 +168,11 @@ func getGrpcClient(ctx context.Context, dest net.Destination, streamSettings *in dialOptions = append(dialOptions, grpc.WithInitialWindowSize(grpcSettings.InitialWindowsSize)) } - if grpcSettings.UserAgent != "" { - dialOptions = append(dialOptions, grpc.WithUserAgent(grpcSettings.UserAgent)) + userAgent := grpcSettings.UserAgent + if userAgent == "" { + userAgent = utils.ChromeUA } + dialOptions = append(dialOptions, grpc.WithUserAgent(userAgent)) var grpcDestHost string if dest.Address.Family().IsDomain() { diff --git a/transport/internet/grpc/encoding/stream_grpc.pb.go b/transport/internet/grpc/encoding/stream_grpc.pb.go index 1fe11524..eae77c41 100644 --- a/transport/internet/grpc/encoding/stream_grpc.pb.go +++ b/transport/internet/grpc/encoding/stream_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v5.28.2 +// - protoc-gen-go-grpc v1.6.1 +// - protoc v3.21.12 // source: transport/internet/grpc/encoding/stream.proto package encoding @@ -82,10 +82,10 @@ type GRPCServiceServer interface { type UnimplementedGRPCServiceServer struct{} func (UnimplementedGRPCServiceServer) Tun(grpc.BidiStreamingServer[Hunk, Hunk]) error { - return status.Errorf(codes.Unimplemented, "method Tun not implemented") + return status.Error(codes.Unimplemented, "method Tun not implemented") } func (UnimplementedGRPCServiceServer) TunMulti(grpc.BidiStreamingServer[MultiHunk, MultiHunk]) error { - return status.Errorf(codes.Unimplemented, "method TunMulti not implemented") + return status.Error(codes.Unimplemented, "method TunMulti not implemented") } func (UnimplementedGRPCServiceServer) mustEmbedUnimplementedGRPCServiceServer() {} func (UnimplementedGRPCServiceServer) testEmbeddedByValue() {} @@ -98,7 +98,7 @@ type UnsafeGRPCServiceServer interface { } func RegisterGRPCServiceServer(s grpc.ServiceRegistrar, srv GRPCServiceServer) { - // If the following call pancis, it indicates UnimplementedGRPCServiceServer was + // If the following call panics, it indicates UnimplementedGRPCServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. diff --git a/transport/internet/header_test.go b/transport/internet/header_test.go deleted file mode 100644 index 272d2fad..00000000 --- a/transport/internet/header_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package internet_test - -import ( - "testing" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - . "github.com/amnezia-vpn/amnezia-xray-core/transport/internet" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/noop" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/srtp" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/utp" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wechat" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wireguard" -) - -func TestAllHeadersLoadable(t *testing.T) { - testCases := []struct { - Input interface{} - Size int32 - }{ - { - Input: new(noop.Config), - Size: 0, - }, - { - Input: new(srtp.Config), - Size: 4, - }, - { - Input: new(utp.Config), - Size: 4, - }, - { - Input: new(wechat.VideoConfig), - Size: 13, - }, - { - Input: new(wireguard.WireguardConfig), - Size: 4, - }, - } - - for _, testCase := range testCases { - header, err := CreatePacketHeader(testCase.Input) - common.Must(err) - if header.Size() != testCase.Size { - t.Error("expected size ", testCase.Size, " but got ", header.Size()) - } - } -} diff --git a/transport/internet/headers/dns/config.pb.go b/transport/internet/headers/dns/config.pb.go deleted file mode 100644 index 82f9f9cb..00000000 --- a/transport/internet/headers/dns/config.pb.go +++ /dev/null @@ -1,123 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.36.11 -// protoc v3.21.12 -// source: transport/internet/headers/dns/config.proto - -package dns - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" - unsafe "unsafe" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type Config struct { - state protoimpl.MessageState `protogen:"open.v1"` - Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *Config) Reset() { - *x = Config{} - mi := &file_transport_internet_headers_dns_config_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *Config) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Config) ProtoMessage() {} - -func (x *Config) ProtoReflect() protoreflect.Message { - mi := &file_transport_internet_headers_dns_config_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Config.ProtoReflect.Descriptor instead. -func (*Config) Descriptor() ([]byte, []int) { - return file_transport_internet_headers_dns_config_proto_rawDescGZIP(), []int{0} -} - -func (x *Config) GetDomain() string { - if x != nil { - return x.Domain - } - return "" -} - -var File_transport_internet_headers_dns_config_proto protoreflect.FileDescriptor - -const file_transport_internet_headers_dns_config_proto_rawDesc = "" + - "\n" + - "+transport/internet/headers/dns/config.proto\x12#xray.transport.internet.headers.dns\" \n" + - "\x06Config\x12\x16\n" + - "\x06domain\x18\x01 \x01(\tR\x06domainB\x9a\x01\n" + - "'com.xray.transport.internet.headers.dnsP\x01ZGgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/dns\xaa\x02#Xray.Transport.Internet.Headers.DNSb\x06proto3" - -var ( - file_transport_internet_headers_dns_config_proto_rawDescOnce sync.Once - file_transport_internet_headers_dns_config_proto_rawDescData []byte -) - -func file_transport_internet_headers_dns_config_proto_rawDescGZIP() []byte { - file_transport_internet_headers_dns_config_proto_rawDescOnce.Do(func() { - file_transport_internet_headers_dns_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_headers_dns_config_proto_rawDesc), len(file_transport_internet_headers_dns_config_proto_rawDesc))) - }) - return file_transport_internet_headers_dns_config_proto_rawDescData -} - -var file_transport_internet_headers_dns_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) -var file_transport_internet_headers_dns_config_proto_goTypes = []any{ - (*Config)(nil), // 0: xray.transport.internet.headers.dns.Config -} -var file_transport_internet_headers_dns_config_proto_depIdxs = []int32{ - 0, // [0:0] is the sub-list for method output_type - 0, // [0:0] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name -} - -func init() { file_transport_internet_headers_dns_config_proto_init() } -func file_transport_internet_headers_dns_config_proto_init() { - if File_transport_internet_headers_dns_config_proto != nil { - return - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_headers_dns_config_proto_rawDesc), len(file_transport_internet_headers_dns_config_proto_rawDesc)), - NumEnums: 0, - NumMessages: 1, - NumExtensions: 0, - NumServices: 0, - }, - GoTypes: file_transport_internet_headers_dns_config_proto_goTypes, - DependencyIndexes: file_transport_internet_headers_dns_config_proto_depIdxs, - MessageInfos: file_transport_internet_headers_dns_config_proto_msgTypes, - }.Build() - File_transport_internet_headers_dns_config_proto = out.File - file_transport_internet_headers_dns_config_proto_goTypes = nil - file_transport_internet_headers_dns_config_proto_depIdxs = nil -} diff --git a/transport/internet/headers/dns/config.proto b/transport/internet/headers/dns/config.proto deleted file mode 100644 index 9dbf7b87..00000000 --- a/transport/internet/headers/dns/config.proto +++ /dev/null @@ -1,12 +0,0 @@ -syntax = "proto3"; - -package xray.transport.internet.headers.dns; -option csharp_namespace = "Xray.Transport.Internet.Headers.DNS"; -option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/dns"; -option java_package = "com.xray.transport.internet.headers.dns"; -option java_multiple_files = true; - -message Config { - string domain = 1; -} - diff --git a/transport/internet/headers/dns/dns.go b/transport/internet/headers/dns/dns.go deleted file mode 100644 index c1b93462..00000000 --- a/transport/internet/headers/dns/dns.go +++ /dev/null @@ -1,123 +0,0 @@ -package dns - -import ( - "context" - "encoding/binary" - "errors" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - "github.com/amnezia-vpn/amnezia-xray-core/common/dice" -) - -type DNS struct { - header []byte -} - -func (d DNS) Size() int32 { - return int32(len(d.header)) -} - -// Serialize implements PacketHeader. -func (d DNS) Serialize(b []byte) { - copy(b, d.header) - binary.BigEndian.PutUint16(b[0:], dice.RollUint16()) // random transaction ID -} - -// NewDNS returns a new DNS instance based on given config. -func NewDNS(ctx context.Context, config interface{}) (interface{}, error) { - var header []byte - - header = binary.BigEndian.AppendUint16(header, 0x0000) // Transaction ID - header = binary.BigEndian.AppendUint16(header, 0x0100) // Flags: Standard query - header = binary.BigEndian.AppendUint16(header, 0x0001) // Questions - header = binary.BigEndian.AppendUint16(header, 0x0000) // Answer RRs - header = binary.BigEndian.AppendUint16(header, 0x0000) // Authority RRs - header = binary.BigEndian.AppendUint16(header, 0x0000) // Additional RRs - - buf := make([]byte, 0x100) - - off1, err := packDomainName(config.(*Config).Domain+".", buf) - if err != nil { - return nil, err - } - - header = append(header, buf[:off1]...) - - header = binary.BigEndian.AppendUint16(header, 0x0001) // Type: A - header = binary.BigEndian.AppendUint16(header, 0x0001) // Class: IN - - return DNS{ - header: header, - }, nil -} - -// copied from github.com/miekg/dns -func packDomainName(s string, msg []byte) (off1 int, err error) { - off := 0 - ls := len(s) - // Each dot ends a segment of the name. - // We trade each dot byte for a length byte. - // Except for escaped dots (\.), which are normal dots. - // There is also a trailing zero. - - // Emit sequence of counted strings, chopping at dots. - var ( - begin int - bs []byte - ) - for i := 0; i < ls; i++ { - var c byte - if bs == nil { - c = s[i] - } else { - c = bs[i] - } - - switch c { - case '\\': - if off+1 > len(msg) { - return len(msg), errors.New("buffer size too small") - } - - if bs == nil { - bs = []byte(s) - } - - copy(bs[i:ls-1], bs[i+1:]) - ls-- - case '.': - labelLen := i - begin - if labelLen >= 1<<6 { // top two bits of length must be clear - return len(msg), errors.New("bad rdata") - } - - // off can already (we're in a loop) be bigger than len(msg) - // this happens when a name isn't fully qualified - if off+1+labelLen > len(msg) { - return len(msg), errors.New("buffer size too small") - } - - // The following is covered by the length check above. - msg[off] = byte(labelLen) - - if bs == nil { - copy(msg[off+1:], s[begin:i]) - } else { - copy(msg[off+1:], bs[begin:i]) - } - off += 1 + labelLen - begin = i + 1 - default: - } - } - - if off < len(msg) { - msg[off] = 0 - } - - return off + 1, nil -} - -func init() { - common.Must(common.RegisterConfig((*Config)(nil), NewDNS)) -} diff --git a/transport/internet/headers/srtp/config.pb.go b/transport/internet/headers/srtp/config.pb.go deleted file mode 100644 index b7097b6f..00000000 --- a/transport/internet/headers/srtp/config.pb.go +++ /dev/null @@ -1,169 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.36.11 -// protoc v3.21.12 -// source: transport/internet/headers/srtp/config.proto - -package srtp - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" - unsafe "unsafe" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type Config struct { - state protoimpl.MessageState `protogen:"open.v1"` - Version uint32 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` - Padding bool `protobuf:"varint,2,opt,name=padding,proto3" json:"padding,omitempty"` - Extension bool `protobuf:"varint,3,opt,name=extension,proto3" json:"extension,omitempty"` - CsrcCount uint32 `protobuf:"varint,4,opt,name=csrc_count,json=csrcCount,proto3" json:"csrc_count,omitempty"` - Marker bool `protobuf:"varint,5,opt,name=marker,proto3" json:"marker,omitempty"` - PayloadType uint32 `protobuf:"varint,6,opt,name=payload_type,json=payloadType,proto3" json:"payload_type,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *Config) Reset() { - *x = Config{} - mi := &file_transport_internet_headers_srtp_config_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *Config) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Config) ProtoMessage() {} - -func (x *Config) ProtoReflect() protoreflect.Message { - mi := &file_transport_internet_headers_srtp_config_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Config.ProtoReflect.Descriptor instead. -func (*Config) Descriptor() ([]byte, []int) { - return file_transport_internet_headers_srtp_config_proto_rawDescGZIP(), []int{0} -} - -func (x *Config) GetVersion() uint32 { - if x != nil { - return x.Version - } - return 0 -} - -func (x *Config) GetPadding() bool { - if x != nil { - return x.Padding - } - return false -} - -func (x *Config) GetExtension() bool { - if x != nil { - return x.Extension - } - return false -} - -func (x *Config) GetCsrcCount() uint32 { - if x != nil { - return x.CsrcCount - } - return 0 -} - -func (x *Config) GetMarker() bool { - if x != nil { - return x.Marker - } - return false -} - -func (x *Config) GetPayloadType() uint32 { - if x != nil { - return x.PayloadType - } - return 0 -} - -var File_transport_internet_headers_srtp_config_proto protoreflect.FileDescriptor - -const file_transport_internet_headers_srtp_config_proto_rawDesc = "" + - "\n" + - ",transport/internet/headers/srtp/config.proto\x12$xray.transport.internet.headers.srtp\"\xb4\x01\n" + - "\x06Config\x12\x18\n" + - "\aversion\x18\x01 \x01(\rR\aversion\x12\x18\n" + - "\apadding\x18\x02 \x01(\bR\apadding\x12\x1c\n" + - "\textension\x18\x03 \x01(\bR\textension\x12\x1d\n" + - "\n" + - "csrc_count\x18\x04 \x01(\rR\tcsrcCount\x12\x16\n" + - "\x06marker\x18\x05 \x01(\bR\x06marker\x12!\n" + - "\fpayload_type\x18\x06 \x01(\rR\vpayloadTypeB\x9d\x01\n" + - "(com.xray.transport.internet.headers.srtpP\x01ZHgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/srtp\xaa\x02$Xray.Transport.Internet.Headers.Srtpb\x06proto3" - -var ( - file_transport_internet_headers_srtp_config_proto_rawDescOnce sync.Once - file_transport_internet_headers_srtp_config_proto_rawDescData []byte -) - -func file_transport_internet_headers_srtp_config_proto_rawDescGZIP() []byte { - file_transport_internet_headers_srtp_config_proto_rawDescOnce.Do(func() { - file_transport_internet_headers_srtp_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_headers_srtp_config_proto_rawDesc), len(file_transport_internet_headers_srtp_config_proto_rawDesc))) - }) - return file_transport_internet_headers_srtp_config_proto_rawDescData -} - -var file_transport_internet_headers_srtp_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) -var file_transport_internet_headers_srtp_config_proto_goTypes = []any{ - (*Config)(nil), // 0: xray.transport.internet.headers.srtp.Config -} -var file_transport_internet_headers_srtp_config_proto_depIdxs = []int32{ - 0, // [0:0] is the sub-list for method output_type - 0, // [0:0] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name -} - -func init() { file_transport_internet_headers_srtp_config_proto_init() } -func file_transport_internet_headers_srtp_config_proto_init() { - if File_transport_internet_headers_srtp_config_proto != nil { - return - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_headers_srtp_config_proto_rawDesc), len(file_transport_internet_headers_srtp_config_proto_rawDesc)), - NumEnums: 0, - NumMessages: 1, - NumExtensions: 0, - NumServices: 0, - }, - GoTypes: file_transport_internet_headers_srtp_config_proto_goTypes, - DependencyIndexes: file_transport_internet_headers_srtp_config_proto_depIdxs, - MessageInfos: file_transport_internet_headers_srtp_config_proto_msgTypes, - }.Build() - File_transport_internet_headers_srtp_config_proto = out.File - file_transport_internet_headers_srtp_config_proto_goTypes = nil - file_transport_internet_headers_srtp_config_proto_depIdxs = nil -} diff --git a/transport/internet/headers/srtp/config.proto b/transport/internet/headers/srtp/config.proto deleted file mode 100644 index 8d8a122a..00000000 --- a/transport/internet/headers/srtp/config.proto +++ /dev/null @@ -1,16 +0,0 @@ -syntax = "proto3"; - -package xray.transport.internet.headers.srtp; -option csharp_namespace = "Xray.Transport.Internet.Headers.Srtp"; -option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/srtp"; -option java_package = "com.xray.transport.internet.headers.srtp"; -option java_multiple_files = true; - -message Config { - uint32 version = 1; - bool padding = 2; - bool extension = 3; - uint32 csrc_count = 4; - bool marker = 5; - uint32 payload_type = 6; -} diff --git a/transport/internet/headers/srtp/srtp.go b/transport/internet/headers/srtp/srtp.go deleted file mode 100644 index 390f48e5..00000000 --- a/transport/internet/headers/srtp/srtp.go +++ /dev/null @@ -1,37 +0,0 @@ -package srtp - -import ( - "context" - "encoding/binary" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - "github.com/amnezia-vpn/amnezia-xray-core/common/dice" -) - -type SRTP struct { - header uint16 - number uint16 -} - -func (*SRTP) Size() int32 { - return 4 -} - -// Serialize implements PacketHeader. -func (s *SRTP) Serialize(b []byte) { - s.number++ - binary.BigEndian.PutUint16(b, s.header) - binary.BigEndian.PutUint16(b[2:], s.number) -} - -// New returns a new SRTP instance based on the given config. -func New(ctx context.Context, config interface{}) (interface{}, error) { - return &SRTP{ - header: 0xB5E8, - number: dice.RollUint16(), - }, nil -} - -func init() { - common.Must(common.RegisterConfig((*Config)(nil), New)) -} diff --git a/transport/internet/headers/srtp/srtp_test.go b/transport/internet/headers/srtp/srtp_test.go deleted file mode 100644 index 5fe00f11..00000000 --- a/transport/internet/headers/srtp/srtp_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package srtp_test - -import ( - "context" - "testing" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - "github.com/amnezia-vpn/amnezia-xray-core/common/buf" - . "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/srtp" -) - -func TestSRTPWrite(t *testing.T) { - content := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'} - srtpRaw, err := New(context.Background(), &Config{}) - common.Must(err) - - srtp := srtpRaw.(*SRTP) - - payload := buf.New() - srtp.Serialize(payload.Extend(srtp.Size())) - payload.Write(content) - - expectedLen := int32(len(content)) + srtp.Size() - if payload.Len() != expectedLen { - t.Error("expected ", expectedLen, " of bytes, but got ", payload.Len()) - } -} diff --git a/transport/internet/headers/tls/config.pb.go b/transport/internet/headers/tls/config.pb.go deleted file mode 100644 index 0c4b7755..00000000 --- a/transport/internet/headers/tls/config.pb.go +++ /dev/null @@ -1,114 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.36.11 -// protoc v3.21.12 -// source: transport/internet/headers/tls/config.proto - -package tls - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" - unsafe "unsafe" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type PacketConfig struct { - state protoimpl.MessageState `protogen:"open.v1"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *PacketConfig) Reset() { - *x = PacketConfig{} - mi := &file_transport_internet_headers_tls_config_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *PacketConfig) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*PacketConfig) ProtoMessage() {} - -func (x *PacketConfig) ProtoReflect() protoreflect.Message { - mi := &file_transport_internet_headers_tls_config_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use PacketConfig.ProtoReflect.Descriptor instead. -func (*PacketConfig) Descriptor() ([]byte, []int) { - return file_transport_internet_headers_tls_config_proto_rawDescGZIP(), []int{0} -} - -var File_transport_internet_headers_tls_config_proto protoreflect.FileDescriptor - -const file_transport_internet_headers_tls_config_proto_rawDesc = "" + - "\n" + - "+transport/internet/headers/tls/config.proto\x12#xray.transport.internet.headers.tls\"\x0e\n" + - "\fPacketConfigB\x9a\x01\n" + - "'com.xray.transport.internet.headers.tlsP\x01ZGgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/tls\xaa\x02#Xray.Transport.Internet.Headers.Tlsb\x06proto3" - -var ( - file_transport_internet_headers_tls_config_proto_rawDescOnce sync.Once - file_transport_internet_headers_tls_config_proto_rawDescData []byte -) - -func file_transport_internet_headers_tls_config_proto_rawDescGZIP() []byte { - file_transport_internet_headers_tls_config_proto_rawDescOnce.Do(func() { - file_transport_internet_headers_tls_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_headers_tls_config_proto_rawDesc), len(file_transport_internet_headers_tls_config_proto_rawDesc))) - }) - return file_transport_internet_headers_tls_config_proto_rawDescData -} - -var file_transport_internet_headers_tls_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) -var file_transport_internet_headers_tls_config_proto_goTypes = []any{ - (*PacketConfig)(nil), // 0: xray.transport.internet.headers.tls.PacketConfig -} -var file_transport_internet_headers_tls_config_proto_depIdxs = []int32{ - 0, // [0:0] is the sub-list for method output_type - 0, // [0:0] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name -} - -func init() { file_transport_internet_headers_tls_config_proto_init() } -func file_transport_internet_headers_tls_config_proto_init() { - if File_transport_internet_headers_tls_config_proto != nil { - return - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_headers_tls_config_proto_rawDesc), len(file_transport_internet_headers_tls_config_proto_rawDesc)), - NumEnums: 0, - NumMessages: 1, - NumExtensions: 0, - NumServices: 0, - }, - GoTypes: file_transport_internet_headers_tls_config_proto_goTypes, - DependencyIndexes: file_transport_internet_headers_tls_config_proto_depIdxs, - MessageInfos: file_transport_internet_headers_tls_config_proto_msgTypes, - }.Build() - File_transport_internet_headers_tls_config_proto = out.File - file_transport_internet_headers_tls_config_proto_goTypes = nil - file_transport_internet_headers_tls_config_proto_depIdxs = nil -} diff --git a/transport/internet/headers/tls/config.proto b/transport/internet/headers/tls/config.proto deleted file mode 100644 index c094b4b6..00000000 --- a/transport/internet/headers/tls/config.proto +++ /dev/null @@ -1,9 +0,0 @@ -syntax = "proto3"; - -package xray.transport.internet.headers.tls; -option csharp_namespace = "Xray.Transport.Internet.Headers.Tls"; -option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/tls"; -option java_package = "com.xray.transport.internet.headers.tls"; -option java_multiple_files = true; - -message PacketConfig {} diff --git a/transport/internet/headers/tls/dtls.go b/transport/internet/headers/tls/dtls.go deleted file mode 100644 index 9398f6a9..00000000 --- a/transport/internet/headers/tls/dtls.go +++ /dev/null @@ -1,55 +0,0 @@ -package tls - -import ( - "context" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - "github.com/amnezia-vpn/amnezia-xray-core/common/dice" -) - -// DTLS writes header as DTLS. See https://tools.ietf.org/html/rfc6347 -type DTLS struct { - epoch uint16 - length uint16 - sequence uint32 -} - -// Size implements PacketHeader. -func (*DTLS) Size() int32 { - return 1 + 2 + 2 + 6 + 2 -} - -// Serialize implements PacketHeader. -func (d *DTLS) Serialize(b []byte) { - b[0] = 23 // application data - b[1] = 254 - b[2] = 253 - b[3] = byte(d.epoch >> 8) - b[4] = byte(d.epoch) - b[5] = 0 - b[6] = 0 - b[7] = byte(d.sequence >> 24) - b[8] = byte(d.sequence >> 16) - b[9] = byte(d.sequence >> 8) - b[10] = byte(d.sequence) - d.sequence++ - b[11] = byte(d.length >> 8) - b[12] = byte(d.length) - d.length += 17 - if d.length > 100 { - d.length -= 50 - } -} - -// New creates a new UTP header for the given config. -func New(ctx context.Context, config interface{}) (interface{}, error) { - return &DTLS{ - epoch: dice.RollUint16(), - sequence: 0, - length: 17, - }, nil -} - -func init() { - common.Must(common.RegisterConfig((*PacketConfig)(nil), New)) -} diff --git a/transport/internet/headers/tls/dtls_test.go b/transport/internet/headers/tls/dtls_test.go deleted file mode 100644 index 20300932..00000000 --- a/transport/internet/headers/tls/dtls_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package tls_test - -import ( - "context" - "testing" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - "github.com/amnezia-vpn/amnezia-xray-core/common/buf" - . "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/tls" -) - -func TestDTLSWrite(t *testing.T) { - content := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'} - dtlsRaw, err := New(context.Background(), &PacketConfig{}) - common.Must(err) - - dtls := dtlsRaw.(*DTLS) - - payload := buf.New() - dtls.Serialize(payload.Extend(dtls.Size())) - payload.Write(content) - - if payload.Len() != int32(len(content))+dtls.Size() { - t.Error("payload len: ", payload.Len(), " want ", int32(len(content))+dtls.Size()) - } -} diff --git a/transport/internet/headers/utp/config.pb.go b/transport/internet/headers/utp/config.pb.go deleted file mode 100644 index 26943386..00000000 --- a/transport/internet/headers/utp/config.pb.go +++ /dev/null @@ -1,123 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.36.11 -// protoc v3.21.12 -// source: transport/internet/headers/utp/config.proto - -package utp - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" - unsafe "unsafe" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type Config struct { - state protoimpl.MessageState `protogen:"open.v1"` - Version uint32 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *Config) Reset() { - *x = Config{} - mi := &file_transport_internet_headers_utp_config_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *Config) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Config) ProtoMessage() {} - -func (x *Config) ProtoReflect() protoreflect.Message { - mi := &file_transport_internet_headers_utp_config_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Config.ProtoReflect.Descriptor instead. -func (*Config) Descriptor() ([]byte, []int) { - return file_transport_internet_headers_utp_config_proto_rawDescGZIP(), []int{0} -} - -func (x *Config) GetVersion() uint32 { - if x != nil { - return x.Version - } - return 0 -} - -var File_transport_internet_headers_utp_config_proto protoreflect.FileDescriptor - -const file_transport_internet_headers_utp_config_proto_rawDesc = "" + - "\n" + - "+transport/internet/headers/utp/config.proto\x12#xray.transport.internet.headers.utp\"\"\n" + - "\x06Config\x12\x18\n" + - "\aversion\x18\x01 \x01(\rR\aversionB\x9a\x01\n" + - "'com.xray.transport.internet.headers.utpP\x01ZGgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/utp\xaa\x02#Xray.Transport.Internet.Headers.Utpb\x06proto3" - -var ( - file_transport_internet_headers_utp_config_proto_rawDescOnce sync.Once - file_transport_internet_headers_utp_config_proto_rawDescData []byte -) - -func file_transport_internet_headers_utp_config_proto_rawDescGZIP() []byte { - file_transport_internet_headers_utp_config_proto_rawDescOnce.Do(func() { - file_transport_internet_headers_utp_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_headers_utp_config_proto_rawDesc), len(file_transport_internet_headers_utp_config_proto_rawDesc))) - }) - return file_transport_internet_headers_utp_config_proto_rawDescData -} - -var file_transport_internet_headers_utp_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) -var file_transport_internet_headers_utp_config_proto_goTypes = []any{ - (*Config)(nil), // 0: xray.transport.internet.headers.utp.Config -} -var file_transport_internet_headers_utp_config_proto_depIdxs = []int32{ - 0, // [0:0] is the sub-list for method output_type - 0, // [0:0] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name -} - -func init() { file_transport_internet_headers_utp_config_proto_init() } -func file_transport_internet_headers_utp_config_proto_init() { - if File_transport_internet_headers_utp_config_proto != nil { - return - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_headers_utp_config_proto_rawDesc), len(file_transport_internet_headers_utp_config_proto_rawDesc)), - NumEnums: 0, - NumMessages: 1, - NumExtensions: 0, - NumServices: 0, - }, - GoTypes: file_transport_internet_headers_utp_config_proto_goTypes, - DependencyIndexes: file_transport_internet_headers_utp_config_proto_depIdxs, - MessageInfos: file_transport_internet_headers_utp_config_proto_msgTypes, - }.Build() - File_transport_internet_headers_utp_config_proto = out.File - file_transport_internet_headers_utp_config_proto_goTypes = nil - file_transport_internet_headers_utp_config_proto_depIdxs = nil -} diff --git a/transport/internet/headers/utp/config.proto b/transport/internet/headers/utp/config.proto deleted file mode 100644 index b9432a40..00000000 --- a/transport/internet/headers/utp/config.proto +++ /dev/null @@ -1,11 +0,0 @@ -syntax = "proto3"; - -package xray.transport.internet.headers.utp; -option csharp_namespace = "Xray.Transport.Internet.Headers.Utp"; -option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/utp"; -option java_package = "com.xray.transport.internet.headers.utp"; -option java_multiple_files = true; - -message Config { - uint32 version = 1; -} diff --git a/transport/internet/headers/utp/utp.go b/transport/internet/headers/utp/utp.go deleted file mode 100644 index 2d28684f..00000000 --- a/transport/internet/headers/utp/utp.go +++ /dev/null @@ -1,39 +0,0 @@ -package utp - -import ( - "context" - "encoding/binary" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - "github.com/amnezia-vpn/amnezia-xray-core/common/dice" -) - -type UTP struct { - header byte - extension byte - connectionID uint16 -} - -func (*UTP) Size() int32 { - return 4 -} - -// Serialize implements PacketHeader. -func (u *UTP) Serialize(b []byte) { - binary.BigEndian.PutUint16(b, u.connectionID) - b[2] = u.header - b[3] = u.extension -} - -// New creates a new UTP header for the given config. -func New(ctx context.Context, config interface{}) (interface{}, error) { - return &UTP{ - header: 1, - extension: 0, - connectionID: dice.RollUint16(), - }, nil -} - -func init() { - common.Must(common.RegisterConfig((*Config)(nil), New)) -} diff --git a/transport/internet/headers/utp/utp_test.go b/transport/internet/headers/utp/utp_test.go deleted file mode 100644 index b5ca420c..00000000 --- a/transport/internet/headers/utp/utp_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package utp_test - -import ( - "context" - "testing" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - "github.com/amnezia-vpn/amnezia-xray-core/common/buf" - . "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/utp" -) - -func TestUTPWrite(t *testing.T) { - content := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'} - utpRaw, err := New(context.Background(), &Config{}) - common.Must(err) - - utp := utpRaw.(*UTP) - - payload := buf.New() - utp.Serialize(payload.Extend(utp.Size())) - payload.Write(content) - - if payload.Len() != int32(len(content))+utp.Size() { - t.Error("unexpected payload length: ", payload.Len()) - } -} diff --git a/transport/internet/headers/wechat/config.pb.go b/transport/internet/headers/wechat/config.pb.go deleted file mode 100644 index a91908c2..00000000 --- a/transport/internet/headers/wechat/config.pb.go +++ /dev/null @@ -1,114 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.36.11 -// protoc v3.21.12 -// source: transport/internet/headers/wechat/config.proto - -package wechat - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" - unsafe "unsafe" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type VideoConfig struct { - state protoimpl.MessageState `protogen:"open.v1"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *VideoConfig) Reset() { - *x = VideoConfig{} - mi := &file_transport_internet_headers_wechat_config_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *VideoConfig) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*VideoConfig) ProtoMessage() {} - -func (x *VideoConfig) ProtoReflect() protoreflect.Message { - mi := &file_transport_internet_headers_wechat_config_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use VideoConfig.ProtoReflect.Descriptor instead. -func (*VideoConfig) Descriptor() ([]byte, []int) { - return file_transport_internet_headers_wechat_config_proto_rawDescGZIP(), []int{0} -} - -var File_transport_internet_headers_wechat_config_proto protoreflect.FileDescriptor - -const file_transport_internet_headers_wechat_config_proto_rawDesc = "" + - "\n" + - ".transport/internet/headers/wechat/config.proto\x12&xray.transport.internet.headers.wechat\"\r\n" + - "\vVideoConfigB\xa3\x01\n" + - "*com.xray.transport.internet.headers.wechatP\x01ZJgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wechat\xaa\x02&Xray.Transport.Internet.Headers.Wechatb\x06proto3" - -var ( - file_transport_internet_headers_wechat_config_proto_rawDescOnce sync.Once - file_transport_internet_headers_wechat_config_proto_rawDescData []byte -) - -func file_transport_internet_headers_wechat_config_proto_rawDescGZIP() []byte { - file_transport_internet_headers_wechat_config_proto_rawDescOnce.Do(func() { - file_transport_internet_headers_wechat_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_headers_wechat_config_proto_rawDesc), len(file_transport_internet_headers_wechat_config_proto_rawDesc))) - }) - return file_transport_internet_headers_wechat_config_proto_rawDescData -} - -var file_transport_internet_headers_wechat_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) -var file_transport_internet_headers_wechat_config_proto_goTypes = []any{ - (*VideoConfig)(nil), // 0: xray.transport.internet.headers.wechat.VideoConfig -} -var file_transport_internet_headers_wechat_config_proto_depIdxs = []int32{ - 0, // [0:0] is the sub-list for method output_type - 0, // [0:0] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name -} - -func init() { file_transport_internet_headers_wechat_config_proto_init() } -func file_transport_internet_headers_wechat_config_proto_init() { - if File_transport_internet_headers_wechat_config_proto != nil { - return - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_headers_wechat_config_proto_rawDesc), len(file_transport_internet_headers_wechat_config_proto_rawDesc)), - NumEnums: 0, - NumMessages: 1, - NumExtensions: 0, - NumServices: 0, - }, - GoTypes: file_transport_internet_headers_wechat_config_proto_goTypes, - DependencyIndexes: file_transport_internet_headers_wechat_config_proto_depIdxs, - MessageInfos: file_transport_internet_headers_wechat_config_proto_msgTypes, - }.Build() - File_transport_internet_headers_wechat_config_proto = out.File - file_transport_internet_headers_wechat_config_proto_goTypes = nil - file_transport_internet_headers_wechat_config_proto_depIdxs = nil -} diff --git a/transport/internet/headers/wechat/config.proto b/transport/internet/headers/wechat/config.proto deleted file mode 100644 index 35fd019c..00000000 --- a/transport/internet/headers/wechat/config.proto +++ /dev/null @@ -1,9 +0,0 @@ -syntax = "proto3"; - -package xray.transport.internet.headers.wechat; -option csharp_namespace = "Xray.Transport.Internet.Headers.Wechat"; -option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wechat"; -option java_package = "com.xray.transport.internet.headers.wechat"; -option java_multiple_files = true; - -message VideoConfig {} diff --git a/transport/internet/headers/wechat/wechat.go b/transport/internet/headers/wechat/wechat.go deleted file mode 100644 index 888a76a1..00000000 --- a/transport/internet/headers/wechat/wechat.go +++ /dev/null @@ -1,43 +0,0 @@ -package wechat - -import ( - "context" - "encoding/binary" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - "github.com/amnezia-vpn/amnezia-xray-core/common/dice" -) - -type VideoChat struct { - sn uint32 -} - -func (vc *VideoChat) Size() int32 { - return 13 -} - -// Serialize implements PacketHeader. -func (vc *VideoChat) Serialize(b []byte) { - vc.sn++ - b[0] = 0xa1 - b[1] = 0x08 - binary.BigEndian.PutUint32(b[2:], vc.sn) // b[2:6] - b[6] = 0x00 - b[7] = 0x10 - b[8] = 0x11 - b[9] = 0x18 - b[10] = 0x30 - b[11] = 0x22 - b[12] = 0x30 -} - -// NewVideoChat returns a new VideoChat instance based on given config. -func NewVideoChat(ctx context.Context, config interface{}) (interface{}, error) { - return &VideoChat{ - sn: uint32(dice.RollUint16()), - }, nil -} - -func init() { - common.Must(common.RegisterConfig((*VideoConfig)(nil), NewVideoChat)) -} diff --git a/transport/internet/headers/wechat/wechat_test.go b/transport/internet/headers/wechat/wechat_test.go deleted file mode 100644 index 7755c9bd..00000000 --- a/transport/internet/headers/wechat/wechat_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package wechat_test - -import ( - "context" - "testing" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - "github.com/amnezia-vpn/amnezia-xray-core/common/buf" - . "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wechat" -) - -func TestUTPWrite(t *testing.T) { - videoRaw, err := NewVideoChat(context.Background(), &VideoConfig{}) - common.Must(err) - - video := videoRaw.(*VideoChat) - - payload := buf.New() - video.Serialize(payload.Extend(video.Size())) - - if payload.Len() != video.Size() { - t.Error("expected payload size ", video.Size(), " but got ", payload.Len()) - } -} diff --git a/transport/internet/headers/wireguard/config.pb.go b/transport/internet/headers/wireguard/config.pb.go deleted file mode 100644 index ffae8a47..00000000 --- a/transport/internet/headers/wireguard/config.pb.go +++ /dev/null @@ -1,114 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.36.11 -// protoc v3.21.12 -// source: transport/internet/headers/wireguard/config.proto - -package wireguard - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" - unsafe "unsafe" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type WireguardConfig struct { - state protoimpl.MessageState `protogen:"open.v1"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *WireguardConfig) Reset() { - *x = WireguardConfig{} - mi := &file_transport_internet_headers_wireguard_config_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *WireguardConfig) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*WireguardConfig) ProtoMessage() {} - -func (x *WireguardConfig) ProtoReflect() protoreflect.Message { - mi := &file_transport_internet_headers_wireguard_config_proto_msgTypes[0] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use WireguardConfig.ProtoReflect.Descriptor instead. -func (*WireguardConfig) Descriptor() ([]byte, []int) { - return file_transport_internet_headers_wireguard_config_proto_rawDescGZIP(), []int{0} -} - -var File_transport_internet_headers_wireguard_config_proto protoreflect.FileDescriptor - -const file_transport_internet_headers_wireguard_config_proto_rawDesc = "" + - "\n" + - "1transport/internet/headers/wireguard/config.proto\x12)xray.transport.internet.headers.wireguard\"\x11\n" + - "\x0fWireguardConfigB\xac\x01\n" + - "-com.xray.transport.internet.headers.wireguardP\x01ZMgithub.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wireguard\xaa\x02)Xray.Transport.Internet.Headers.Wireguardb\x06proto3" - -var ( - file_transport_internet_headers_wireguard_config_proto_rawDescOnce sync.Once - file_transport_internet_headers_wireguard_config_proto_rawDescData []byte -) - -func file_transport_internet_headers_wireguard_config_proto_rawDescGZIP() []byte { - file_transport_internet_headers_wireguard_config_proto_rawDescOnce.Do(func() { - file_transport_internet_headers_wireguard_config_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_transport_internet_headers_wireguard_config_proto_rawDesc), len(file_transport_internet_headers_wireguard_config_proto_rawDesc))) - }) - return file_transport_internet_headers_wireguard_config_proto_rawDescData -} - -var file_transport_internet_headers_wireguard_config_proto_msgTypes = make([]protoimpl.MessageInfo, 1) -var file_transport_internet_headers_wireguard_config_proto_goTypes = []any{ - (*WireguardConfig)(nil), // 0: xray.transport.internet.headers.wireguard.WireguardConfig -} -var file_transport_internet_headers_wireguard_config_proto_depIdxs = []int32{ - 0, // [0:0] is the sub-list for method output_type - 0, // [0:0] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name -} - -func init() { file_transport_internet_headers_wireguard_config_proto_init() } -func file_transport_internet_headers_wireguard_config_proto_init() { - if File_transport_internet_headers_wireguard_config_proto != nil { - return - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_transport_internet_headers_wireguard_config_proto_rawDesc), len(file_transport_internet_headers_wireguard_config_proto_rawDesc)), - NumEnums: 0, - NumMessages: 1, - NumExtensions: 0, - NumServices: 0, - }, - GoTypes: file_transport_internet_headers_wireguard_config_proto_goTypes, - DependencyIndexes: file_transport_internet_headers_wireguard_config_proto_depIdxs, - MessageInfos: file_transport_internet_headers_wireguard_config_proto_msgTypes, - }.Build() - File_transport_internet_headers_wireguard_config_proto = out.File - file_transport_internet_headers_wireguard_config_proto_goTypes = nil - file_transport_internet_headers_wireguard_config_proto_depIdxs = nil -} diff --git a/transport/internet/headers/wireguard/config.proto b/transport/internet/headers/wireguard/config.proto deleted file mode 100644 index 07c5cd6c..00000000 --- a/transport/internet/headers/wireguard/config.proto +++ /dev/null @@ -1,9 +0,0 @@ -syntax = "proto3"; - -package xray.transport.internet.headers.wireguard; -option csharp_namespace = "Xray.Transport.Internet.Headers.Wireguard"; -option go_package = "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/headers/wireguard"; -option java_package = "com.xray.transport.internet.headers.wireguard"; -option java_multiple_files = true; - -message WireguardConfig {} diff --git a/transport/internet/headers/wireguard/wireguard.go b/transport/internet/headers/wireguard/wireguard.go deleted file mode 100644 index b9d198bc..00000000 --- a/transport/internet/headers/wireguard/wireguard.go +++ /dev/null @@ -1,30 +0,0 @@ -package wireguard - -import ( - "context" - - "github.com/amnezia-vpn/amnezia-xray-core/common" -) - -type Wireguard struct{} - -func (Wireguard) Size() int32 { - return 4 -} - -// Serialize implements PacketHeader. -func (Wireguard) Serialize(b []byte) { - b[0] = 0x04 - b[1] = 0x00 - b[2] = 0x00 - b[3] = 0x00 -} - -// NewWireguard returns a new VideoChat instance based on given config. -func NewWireguard(ctx context.Context, config interface{}) (interface{}, error) { - return Wireguard{}, nil -} - -func init() { - common.Must(common.RegisterConfig((*WireguardConfig)(nil), NewWireguard)) -} diff --git a/transport/internet/httpupgrade/dialer.go b/transport/internet/httpupgrade/dialer.go index 35661bdf..ecb0d238 100644 --- a/transport/internet/httpupgrade/dialer.go +++ b/transport/internet/httpupgrade/dialer.go @@ -10,6 +10,7 @@ import ( "github.com/amnezia-vpn/amnezia-xray-core/common" "github.com/amnezia-vpn/amnezia-xray-core/common/errors" "github.com/amnezia-vpn/amnezia-xray-core/common/net" + "github.com/amnezia-vpn/amnezia-xray-core/common/utils" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/stat" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/tls" @@ -86,6 +87,9 @@ func dialhttpUpgrade(ctx context.Context, dest net.Destination, streamSettings * for key, value := range transportConfiguration.Header { AddHeader(req.Header, key, value) } + if req.Header.Get("User-Agent") == "" { + req.Header.Set("User-Agent", utils.ChromeUA) + } req.Header.Set("Connection", "Upgrade") req.Header.Set("Upgrade", "websocket") diff --git a/transport/internet/hysteria/dialer.go b/transport/internet/hysteria/dialer.go index edd6bd85..4c1098ea 100644 --- a/transport/internet/hysteria/dialer.go +++ b/transport/internet/hysteria/dialer.go @@ -176,6 +176,7 @@ func (c *client) dial() error { } pktConn, err = udphop.NewUDPHopPacketConn(addr, c.config.IntervalMin, c.config.IntervalMax, c.udphopDialer, pktConn, index) if err != nil { + raw.Close() return errors.New("udphop err").Base(err) } } @@ -183,6 +184,7 @@ func (c *client) dial() error { if c.udpmaskManager != nil { pktConn, err = c.udpmaskManager.WrapPacketConnClient(pktConn) if err != nil { + raw.Close() return errors.New("mask err").Base(err) } } diff --git a/transport/internet/kcp/config.go b/transport/internet/kcp/config.go index 6a617fb1..ad62b44a 100644 --- a/transport/internet/kcp/config.go +++ b/transport/internet/kcp/config.go @@ -1,8 +1,6 @@ package kcp import ( - "crypto/cipher" - "github.com/amnezia-vpn/amnezia-xray-core/common" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet" ) @@ -48,32 +46,12 @@ func (c *Config) GetWriteBufferSize() uint32 { } // GetReadBufferSize returns the size of ReadBuffer in bytes. -func (c *Config) GetReadBufferSize() uint32 { - if c == nil || c.ReadBuffer == nil { - return 2 * 1024 * 1024 - } - return c.ReadBuffer.Size -} - -// GetSecurity returns the security settings. -func (c *Config) GetSecurity() (cipher.AEAD, error) { - if c.Seed != nil { - return NewAEADAESGCMBasedOnSeed(c.Seed.Seed), nil - } - return NewSimpleAuthenticator(), nil -} - -func (c *Config) GetPackerHeader() (internet.PacketHeader, error) { - if c.HeaderConfig != nil { - rawConfig, err := c.HeaderConfig.GetInstance() - if err != nil { - return nil, err - } - - return internet.CreatePacketHeader(rawConfig) - } - return nil, nil -} +// func (c *Config) GetReadBufferSize() uint32 { +// if c == nil || c.ReadBuffer == nil { +// return 2 * 1024 * 1024 +// } +// return c.ReadBuffer.Size +// } func (c *Config) GetSendingInFlightSize() uint32 { size := c.GetUplinkCapacityValue() * 1024 * 1024 / c.GetMTUValue() / (1000 / c.GetTTIValue()) @@ -95,9 +73,9 @@ func (c *Config) GetReceivingInFlightSize() uint32 { return size } -func (c *Config) GetReceivingBufferSize() uint32 { - return c.GetReadBufferSize() / c.GetMTUValue() -} +// func (c *Config) GetReceivingBufferSize() uint32 { +// return c.GetReadBufferSize() / c.GetMTUValue() +// } func init() { common.Must(internet.RegisterProtocolConfigCreator(protocolName, func() interface{} { diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index a5b7091f..38f2579d 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -204,7 +204,7 @@ type Connection struct { } // NewConnection create a new KCP connection between local and remote. -func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, config *Config) *Connection { +func NewConnection(meta ConnMetadata, writer io.Writer, closer io.Closer, config *Config) *Connection { errors.LogInfo(context.Background(), "#", meta.Conversation, " creating connection to ", meta.RemoteAddr) conn := &Connection{ @@ -215,7 +215,7 @@ func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, con dataOutput: signal.NewNotifier(), Config: config, output: NewRetryableWriter(NewSegmentWriter(writer)), - mss: config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead, + mss: config.GetMTUValue() - DataSegmentOverhead, roundTrip: &RoundTripInfo{ rto: 100, minRtt: config.GetTTIValue(), diff --git a/transport/internet/kcp/connection_test.go b/transport/internet/kcp/connection_test.go index 35bcfdd2..8b381ff9 100644 --- a/transport/internet/kcp/connection_test.go +++ b/transport/internet/kcp/connection_test.go @@ -16,9 +16,7 @@ func (NoOpCloser) Close() error { } func TestConnectionReadTimeout(t *testing.T) { - conn := NewConnection(ConnMetadata{Conversation: 1}, &KCPPacketWriter{ - Writer: buf.DiscardBytes, - }, NoOpCloser(0), &Config{}) + conn := NewConnection(ConnMetadata{Conversation: 1}, buf.DiscardBytes, NoOpCloser(0), &Config{}) conn.SetReadDeadline(time.Now().Add(time.Second)) b := make([]byte, 1024) diff --git a/transport/internet/kcp/crypt.go b/transport/internet/kcp/crypt.go deleted file mode 100644 index c100bc33..00000000 --- a/transport/internet/kcp/crypt.go +++ /dev/null @@ -1,77 +0,0 @@ -package kcp - -import ( - "crypto/cipher" - "encoding/binary" - "hash/fnv" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - "github.com/amnezia-vpn/amnezia-xray-core/common/errors" -) - -// SimpleAuthenticator is a legacy AEAD used for KCP encryption. -type SimpleAuthenticator struct{} - -// NewSimpleAuthenticator creates a new SimpleAuthenticator -func NewSimpleAuthenticator() cipher.AEAD { - return &SimpleAuthenticator{} -} - -// NonceSize implements cipher.AEAD.NonceSize(). -func (*SimpleAuthenticator) NonceSize() int { - return 0 -} - -// Overhead implements cipher.AEAD.NonceSize(). -func (*SimpleAuthenticator) Overhead() int { - return 6 -} - -// Seal implements cipher.AEAD.Seal(). -func (a *SimpleAuthenticator) Seal(dst, nonce, plain, extra []byte) []byte { - dst = append(dst, 0, 0, 0, 0, 0, 0) // 4 bytes for hash, and then 2 bytes for length - binary.BigEndian.PutUint16(dst[4:], uint16(len(plain))) - dst = append(dst, plain...) - - fnvHash := fnv.New32a() - common.Must2(fnvHash.Write(dst[4:])) - fnvHash.Sum(dst[:0]) - - dstLen := len(dst) - xtra := 4 - dstLen%4 - if xtra != 4 { - dst = append(dst, make([]byte, xtra)...) - } - xorfwd(dst) - if xtra != 4 { - dst = dst[:dstLen] - } - return dst -} - -// Open implements cipher.AEAD.Open(). -func (a *SimpleAuthenticator) Open(dst, nonce, cipherText, extra []byte) ([]byte, error) { - dst = append(dst, cipherText...) - dstLen := len(dst) - xtra := 4 - dstLen%4 - if xtra != 4 { - dst = append(dst, make([]byte, xtra)...) - } - xorbkd(dst) - if xtra != 4 { - dst = dst[:dstLen] - } - - fnvHash := fnv.New32a() - common.Must2(fnvHash.Write(dst[4:])) - if binary.BigEndian.Uint32(dst[:4]) != fnvHash.Sum32() { - return nil, errors.New("invalid auth") - } - - length := binary.BigEndian.Uint16(dst[4:6]) - if len(dst)-6 != int(length) { - return nil, errors.New("invalid auth") - } - - return dst[6:], nil -} diff --git a/transport/internet/kcp/crypt_test.go b/transport/internet/kcp/crypt_test.go deleted file mode 100644 index 538f9f9e..00000000 --- a/transport/internet/kcp/crypt_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package kcp_test - -import ( - "testing" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - . "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/kcp" - "github.com/google/go-cmp/cmp" -) - -func TestSimpleAuthenticator(t *testing.T) { - cache := make([]byte, 512) - - payload := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g'} - - auth := NewSimpleAuthenticator() - b := auth.Seal(cache[:0], nil, payload, nil) - c, err := auth.Open(cache[:0], nil, b, nil) - common.Must(err) - if r := cmp.Diff(c, payload); r != "" { - t.Error(r) - } -} - -func TestSimpleAuthenticator2(t *testing.T) { - cache := make([]byte, 512) - - payload := []byte{'a', 'b'} - - auth := NewSimpleAuthenticator() - b := auth.Seal(cache[:0], nil, payload, nil) - c, err := auth.Open(cache[:0], nil, b, nil) - common.Must(err) - if r := cmp.Diff(c, payload); r != "" { - t.Error(r) - } -} diff --git a/transport/internet/kcp/cryptreal.go b/transport/internet/kcp/cryptreal.go deleted file mode 100644 index 5f434065..00000000 --- a/transport/internet/kcp/cryptreal.go +++ /dev/null @@ -1,13 +0,0 @@ -package kcp - -import ( - "crypto/cipher" - "crypto/sha256" - - "github.com/amnezia-vpn/amnezia-xray-core/common/crypto" -) - -func NewAEADAESGCMBasedOnSeed(seed string) cipher.AEAD { - hashedSeed := sha256.Sum256([]byte(seed)) - return crypto.NewAesGcm(hashedSeed[:16]) -} diff --git a/transport/internet/kcp/dialer.go b/transport/internet/kcp/dialer.go index bf9190ee..18e32720 100644 --- a/transport/internet/kcp/dialer.go +++ b/transport/internet/kcp/dialer.go @@ -54,32 +54,32 @@ func DialKCP(ctx context.Context, dest net.Destination, streamSettings *internet return nil, errors.New("failed to dial to dest: ", err).AtWarning().Base(err) } + if streamSettings.UdpmaskManager != nil { + wrapper, ok := rawConn.(*internet.PacketConnWrapper) + if !ok { + rawConn.Close() + return nil, errors.New("raw is not PacketConnWrapper") + } + + raw := wrapper.Conn + + wrapper.Conn, err = streamSettings.UdpmaskManager.WrapPacketConnClient(raw) + if err != nil { + raw.Close() + return nil, errors.New("mask err").Base(err) + } + } + kcpSettings := streamSettings.ProtocolSettings.(*Config) - header, err := kcpSettings.GetPackerHeader() - if err != nil { - return nil, errors.New("failed to create packet header").Base(err) - } - security, err := kcpSettings.GetSecurity() - if err != nil { - return nil, errors.New("failed to create security").Base(err) - } - reader := &KCPPacketReader{ - Header: header, - Security: security, - } - writer := &KCPPacketWriter{ - Header: header, - Security: security, - Writer: rawConn, - } + reader := &KCPPacketReader{} conv := uint16(atomic.AddUint32(&globalConv, 1)) session := NewConnection(ConnMetadata{ LocalAddr: rawConn.LocalAddr(), RemoteAddr: rawConn.RemoteAddr(), Conversation: conv, - }, writer, rawConn, kcpSettings) + }, rawConn, rawConn, kcpSettings) go fetchInput(ctx, rawConn, reader, session) diff --git a/transport/internet/kcp/io.go b/transport/internet/kcp/io.go index f3b66723..fac9945f 100644 --- a/transport/internet/kcp/io.go +++ b/transport/internet/kcp/io.go @@ -1,48 +1,12 @@ package kcp -import ( - "crypto/cipher" - "crypto/rand" - "io" - - "github.com/amnezia-vpn/amnezia-xray-core/common" - "github.com/amnezia-vpn/amnezia-xray-core/common/buf" - "github.com/amnezia-vpn/amnezia-xray-core/transport/internet" -) - type PacketReader interface { Read([]byte) []Segment } -type PacketWriter interface { - Overhead() int - io.Writer -} - -type KCPPacketReader struct { - Security cipher.AEAD - Header internet.PacketHeader -} +type KCPPacketReader struct{} func (r *KCPPacketReader) Read(b []byte) []Segment { - if r.Header != nil { - if int32(len(b)) <= r.Header.Size() { - return nil - } - b = b[r.Header.Size():] - } - if r.Security != nil { - nonceSize := r.Security.NonceSize() - overhead := r.Security.Overhead() - if len(b) <= nonceSize+overhead { - return nil - } - out, err := r.Security.Open(b[nonceSize:nonceSize], b[:nonceSize], b[nonceSize:], nil) - if err != nil { - return nil - } - b = out - } var result []Segment for len(b) > 0 { seg, x := ReadSegment(b) @@ -54,42 +18,3 @@ func (r *KCPPacketReader) Read(b []byte) []Segment { } return result } - -type KCPPacketWriter struct { - Header internet.PacketHeader - Security cipher.AEAD - Writer io.Writer -} - -func (w *KCPPacketWriter) Overhead() int { - overhead := 0 - if w.Header != nil { - overhead += int(w.Header.Size()) - } - if w.Security != nil { - overhead += w.Security.Overhead() - } - return overhead -} - -func (w *KCPPacketWriter) Write(b []byte) (int, error) { - bb := buf.StackNew() - defer bb.Release() - - if w.Header != nil { - w.Header.Serialize(bb.Extend(w.Header.Size())) - } - if w.Security != nil { - nonceSize := w.Security.NonceSize() - common.Must2(bb.ReadFullFrom(rand.Reader, int32(nonceSize))) - nonce := bb.BytesFrom(int32(-nonceSize)) - - encrypted := bb.Extend(int32(w.Security.Overhead() + len(b))) - w.Security.Seal(encrypted[:0], nonce, b, nil) - } else { - bb.Write(b) - } - - _, err := w.Writer.Write(bb.Bytes()) - return len(b), err -} diff --git a/transport/internet/kcp/io_test.go b/transport/internet/kcp/io_test.go index 2d26f935..27a47d77 100644 --- a/transport/internet/kcp/io_test.go +++ b/transport/internet/kcp/io_test.go @@ -7,9 +7,7 @@ import ( ) func TestKCPPacketReader(t *testing.T) { - reader := KCPPacketReader{ - Security: &SimpleAuthenticator{}, - } + reader := KCPPacketReader{} testCases := []struct { Input []byte diff --git a/transport/internet/kcp/listener.go b/transport/internet/kcp/listener.go index 717aa3cb..6e18e430 100644 --- a/transport/internet/kcp/listener.go +++ b/transport/internet/kcp/listener.go @@ -2,7 +2,6 @@ package kcp import ( "context" - "crypto/cipher" gotls "crypto/tls" "sync" @@ -30,28 +29,14 @@ type Listener struct { tlsConfig *gotls.Config config *Config reader PacketReader - header internet.PacketHeader - security cipher.AEAD addConn internet.ConnHandler } func NewListener(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (*Listener, error) { kcpSettings := streamSettings.ProtocolSettings.(*Config) - header, err := kcpSettings.GetPackerHeader() - if err != nil { - return nil, errors.New("failed to create packet header").Base(err).AtError() - } - security, err := kcpSettings.GetSecurity() - if err != nil { - return nil, errors.New("failed to create security").Base(err).AtError() - } + l := &Listener{ - header: header, - security: security, - reader: &KCPPacketReader{ - Header: header, - Security: security, - }, + reader: &KCPPacketReader{}, sessions: make(map[ConnectionID]*Connection), config: kcpSettings, addConn: addConn, @@ -124,11 +109,7 @@ func (l *Listener) OnReceive(payload *buf.Buffer, src net.Destination) { LocalAddr: localAddr, RemoteAddr: remoteAddr, Conversation: conv, - }, &KCPPacketWriter{ - Header: l.header, - Security: l.security, - Writer: writer, - }, writer, l.config) + }, writer, writer, l.config) var netConn stat.Connection = conn if l.tlsConfig != nil { netConn = tls.Server(conn, l.tlsConfig) diff --git a/transport/internet/reality/reality.go b/transport/internet/reality/reality.go index 5383b0f4..1d36b9a1 100644 --- a/transport/internet/reality/reality.go +++ b/transport/internet/reality/reality.go @@ -24,6 +24,7 @@ import ( "github.com/amnezia-vpn/amnezia-xray-core/common/crypto" "github.com/amnezia-vpn/amnezia-xray-core/common/errors" "github.com/amnezia-vpn/amnezia-xray-core/common/net" + "github.com/amnezia-vpn/amnezia-xray-core/common/utils" "github.com/amnezia-vpn/amnezia-xray-core/core" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/tls" "github.com/cloudflare/circl/sign/mldsa/mldsa65" @@ -222,7 +223,7 @@ func UClient(c net.Conn, config *Config, ctx context.Context, dest net.Destinati if req == nil { return } - req.Header.Set("User-Agent", fingerprint.Client) // TODO: User-Agent map + req.Header.Set("User-Agent", utils.ChromeUA) if first && config.Show { fmt.Printf("REALITY localAddr: %v\treq.UserAgent(): %v\n", localAddr, req.UserAgent()) } diff --git a/transport/internet/splithttp/browser_client.go b/transport/internet/splithttp/browser_client.go index 254fa856..b61b569e 100644 --- a/transport/internet/splithttp/browser_client.go +++ b/transport/internet/splithttp/browser_client.go @@ -19,12 +19,35 @@ func (c *BrowserDialerClient) IsClosed() bool { panic("not implemented yet") } -func (c *BrowserDialerClient) OpenStream(ctx context.Context, url string, body io.Reader, uploadOnly bool) (io.ReadCloser, net.Addr, net.Addr, error) { +func (c *BrowserDialerClient) OpenStream(ctx context.Context, url string, _ string, body io.Reader, uploadOnly bool) (io.ReadCloser, net.Addr, net.Addr, error) { if body != nil { return nil, nil, nil, errors.New("bidirectional streaming for browser dialer not implemented yet") } - conn, err := browser_dialer.DialGet(url, c.transportConfig.GetRequestHeader(url)) + header := c.transportConfig.GetRequestHeader() + length := int(c.transportConfig.GetNormalizedXPaddingBytes().rand()) + config := XPaddingConfig{Length: length} + + if c.transportConfig.XPaddingObfsMode { + config.Placement = XPaddingPlacement{ + Placement: c.transportConfig.XPaddingPlacement, + Key: c.transportConfig.XPaddingKey, + Header: c.transportConfig.XPaddingHeader, + RawURL: url, + } + config.Method = PaddingMethod(c.transportConfig.XPaddingMethod) + } else { + config.Placement = XPaddingPlacement{ + Placement: PlacementQueryInHeader, + Key: "x_padding", + Header: "Referer", + RawURL: url, + } + } + + c.transportConfig.ApplyXPaddingToHeader(header, config) + + conn, err := browser_dialer.DialGet(url, header) dummyAddr := &net.IPAddr{} if err != nil { return nil, dummyAddr, dummyAddr, err @@ -33,13 +56,36 @@ func (c *BrowserDialerClient) OpenStream(ctx context.Context, url string, body i return websocket.NewConnection(conn, dummyAddr, nil, 0), conn.RemoteAddr(), conn.LocalAddr(), nil } -func (c *BrowserDialerClient) PostPacket(ctx context.Context, url string, body io.Reader, contentLength int64) error { +func (c *BrowserDialerClient) PostPacket(ctx context.Context, url string, _ string, _ string, body io.Reader, contentLength int64) error { bytes, err := io.ReadAll(body) if err != nil { return err } - err = browser_dialer.DialPost(url, c.transportConfig.GetRequestHeader(url), bytes) + header := c.transportConfig.GetRequestHeader() + length := int(c.transportConfig.GetNormalizedXPaddingBytes().rand()) + config := XPaddingConfig{Length: length} + + if c.transportConfig.XPaddingObfsMode { + config.Placement = XPaddingPlacement{ + Placement: c.transportConfig.XPaddingPlacement, + Key: c.transportConfig.XPaddingKey, + Header: c.transportConfig.XPaddingHeader, + RawURL: url, + } + config.Method = PaddingMethod(c.transportConfig.XPaddingMethod) + } else { + config.Placement = XPaddingPlacement{ + Placement: PlacementQueryInHeader, + Key: "x_padding", + Header: "Referer", + RawURL: url, + } + } + + c.transportConfig.ApplyXPaddingToHeader(header, config) + + err = browser_dialer.DialPost(url, header, bytes) if err != nil { return err } diff --git a/transport/internet/splithttp/client.go b/transport/internet/splithttp/client.go index 6163160a..d8f31d1b 100644 --- a/transport/internet/splithttp/client.go +++ b/transport/internet/splithttp/client.go @@ -3,6 +3,7 @@ package splithttp import ( "bytes" "context" + "encoding/base64" "fmt" "io" "net/http" @@ -19,11 +20,11 @@ import ( type DialerClient interface { IsClosed() bool - // ctx, url, body, uploadOnly - OpenStream(context.Context, string, io.Reader, bool) (io.ReadCloser, net.Addr, net.Addr, error) + // ctx, url, sessionId, body, uploadOnly + OpenStream(context.Context, string, string, io.Reader, bool) (io.ReadCloser, net.Addr, net.Addr, error) - // ctx, url, body, contentLength - PostPacket(context.Context, string, io.Reader, int64) error + // ctx, url, sessionId, seqStr, body, contentLength + PostPacket(context.Context, string, string, string, io.Reader, int64) error } // implements splithttp.DialerClient in terms of direct network connections @@ -41,7 +42,7 @@ func (c *DefaultDialerClient) IsClosed() bool { return c.closed } -func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, body io.Reader, uploadOnly bool) (wrc io.ReadCloser, remoteAddr, localAddr net.Addr, err error) { +func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, sessionId string, body io.Reader, uploadOnly bool) (wrc io.ReadCloser, remoteAddr, localAddr net.Addr, err error) { // this is done when the TCP/UDP connection to the server was established, // and we can unblock the Dial function and print correct net addresses in // logs @@ -56,11 +57,34 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, body i method := "GET" // stream-down if body != nil { - method = "POST" // stream-up/one + method = c.transportConfig.GetNormalizedUplinkHTTPMethod() // stream-up/one } req, _ := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, body) - req.Header = c.transportConfig.GetRequestHeader(url) - if method == "POST" && !c.transportConfig.NoGRPCHeader { + req.Header = c.transportConfig.GetRequestHeader() + length := int(c.transportConfig.GetNormalizedXPaddingBytes().rand()) + config := XPaddingConfig{Length: length} + + if c.transportConfig.XPaddingObfsMode { + config.Placement = XPaddingPlacement{ + Placement: c.transportConfig.XPaddingPlacement, + Key: c.transportConfig.XPaddingKey, + Header: c.transportConfig.XPaddingHeader, + RawURL: url, + } + config.Method = PaddingMethod(c.transportConfig.XPaddingMethod) + } else { + config.Placement = XPaddingPlacement{ + Placement: PlacementQueryInHeader, + Key: "x_padding", + Header: "Referer", + RawURL: url, + } + } + + c.transportConfig.ApplyXPaddingToRequest(req, config) + c.transportConfig.ApplyMetaToRequest(req, sessionId, "") + + if method == c.transportConfig.GetNormalizedUplinkHTTPMethod() && !c.transportConfig.NoGRPCHeader { req.Header.Set("Content-Type", "application/grpc") } @@ -92,13 +116,83 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, body i return } -func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, body io.Reader, contentLength int64) error { - req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), "POST", url, body) +func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessionId string, seqStr string, body io.Reader, contentLength int64) error { + var encodedData string + dataPlacement := c.transportConfig.GetNormalizedUplinkDataPlacement() + + if dataPlacement != PlacementBody { + data, err := io.ReadAll(body) + if err != nil { + return err + } + encodedData = base64.RawURLEncoding.EncodeToString(data) + body = nil + contentLength = 0 + } + + method := c.transportConfig.GetNormalizedUplinkHTTPMethod() + req, err := http.NewRequestWithContext(context.WithoutCancel(ctx), method, url, body) if err != nil { return err } req.ContentLength = contentLength - req.Header = c.transportConfig.GetRequestHeader(url) + req.Header = c.transportConfig.GetRequestHeader() + + if dataPlacement != PlacementBody { + key := c.transportConfig.UplinkDataKey + chunkSize := int(c.transportConfig.UplinkChunkSize) + + switch dataPlacement { + case PlacementHeader: + for i := 0; i < len(encodedData); i += chunkSize { + end := i + chunkSize + if end > len(encodedData) { + end = len(encodedData) + } + chunk := encodedData[i:end] + headerKey := fmt.Sprintf("%s-%d", key, i/chunkSize) + req.Header.Set(headerKey, chunk) + } + + req.Header.Set(key+"-Length", fmt.Sprintf("%d", len(encodedData))) + req.Header.Set(key+"-Upstream", "1") + case PlacementCookie: + for i := 0; i < len(encodedData); i += chunkSize { + end := i + chunkSize + if end > len(encodedData) { + end = len(encodedData) + } + chunk := encodedData[i:end] + cookieName := fmt.Sprintf("%s_%d", key, i/chunkSize) + req.AddCookie(&http.Cookie{Name: cookieName, Value: chunk}) + } + + req.AddCookie(&http.Cookie{Name: key + "_upstream", Value: "1"}) + } + } + + length := int(c.transportConfig.GetNormalizedXPaddingBytes().rand()) + config := XPaddingConfig{Length: length} + + if c.transportConfig.XPaddingObfsMode { + config.Placement = XPaddingPlacement{ + Placement: c.transportConfig.XPaddingPlacement, + Key: c.transportConfig.XPaddingKey, + Header: c.transportConfig.XPaddingHeader, + RawURL: url, + } + config.Method = PaddingMethod(c.transportConfig.XPaddingMethod) + } else { + config.Placement = XPaddingPlacement{ + Placement: PlacementQueryInHeader, + Key: "x_padding", + Header: "Referer", + RawURL: url, + } + } + + c.transportConfig.ApplyXPaddingToRequest(req, config) + c.transportConfig.ApplyMetaToRequest(req, sessionId, seqStr) if c.httpVersion != "1.1" { resp, err := c.client.Do(req) diff --git a/transport/internet/splithttp/common.go b/transport/internet/splithttp/common.go new file mode 100644 index 00000000..20596180 --- /dev/null +++ b/transport/internet/splithttp/common.go @@ -0,0 +1,10 @@ +package splithttp + +const ( + PlacementQueryInHeader = "queryInHeader" + PlacementCookie = "cookie" + PlacementHeader = "header" + PlacementQuery = "query" + PlacementPath = "path" + PlacementBody = "body" +) diff --git a/transport/internet/splithttp/config.go b/transport/internet/splithttp/config.go index 2323e81a..deaa57f2 100644 --- a/transport/internet/splithttp/config.go +++ b/transport/internet/splithttp/config.go @@ -2,11 +2,11 @@ package splithttp import ( "net/http" - "net/url" "strings" "github.com/amnezia-vpn/amnezia-xray-core/common" "github.com/amnezia-vpn/amnezia-xray-core/common/crypto" + "github.com/amnezia-vpn/amnezia-xray-core/common/utils" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet" ) @@ -43,41 +43,30 @@ func (c *Config) GetNormalizedQuery() string { return query } -func (c *Config) GetRequestHeader(rawURL string) http.Header { +func (c *Config) GetRequestHeader() http.Header { header := http.Header{} for k, v := range c.Headers { header.Add(k, v) } - - u, _ := url.Parse(rawURL) - // https://www.rfc-editor.org/rfc/rfc7541.html#appendix-B - // h2's HPACK Header Compression feature employs a huffman encoding using a static table. - // 'X' is assigned an 8 bit code, so HPACK compression won't change actual padding length on the wire. - // https://www.rfc-editor.org/rfc/rfc9204.html#section-4.1.2-2 - // h3's similar QPACK feature uses the same huffman table. - u.RawQuery = "x_padding=" + strings.Repeat("X", int(c.GetNormalizedXPaddingBytes().rand())) - header.Set("Referer", u.String()) - + if header.Get("User-Agent") == "" { + header.Set("User-Agent", utils.ChromeUA) + } return header } func (c *Config) WriteResponseHeader(writer http.ResponseWriter) { // CORS headers for the browser dialer writer.Header().Set("Access-Control-Allow-Origin", "*") - writer.Header().Set("Access-Control-Allow-Methods", "GET, POST") + writer.Header().Set("Access-Control-Allow-Methods", "*") // writer.Header().Set("X-Version", core.Version()) - writer.Header().Set("X-Padding", strings.Repeat("X", int(c.GetNormalizedXPaddingBytes().rand()))) } -func (c *Config) GetNormalizedXPaddingBytes() RangeConfig { - if c.XPaddingBytes == nil || c.XPaddingBytes.To == 0 { - return RangeConfig{ - From: 100, - To: 1000, - } +func (c *Config) GetNormalizedUplinkHTTPMethod() string { + if c.UplinkHTTPMethod == "" { + return "POST" } - return *c.XPaddingBytes + return c.UplinkHTTPMethod } func (c *Config) GetNormalizedScMaxEachPostBytes() RangeConfig { @@ -121,6 +110,134 @@ func (c *Config) GetNormalizedScStreamUpServerSecs() RangeConfig { return *c.ScStreamUpServerSecs } +func (c *Config) GetNormalizedSessionPlacement() string { + if c.SessionPlacement == "" { + return PlacementPath + } + return c.SessionPlacement +} + +func (c *Config) GetNormalizedSeqPlacement() string { + if c.SeqPlacement == "" { + return PlacementPath + } + return c.SeqPlacement +} + +func (c *Config) GetNormalizedUplinkDataPlacement() string { + if c.UplinkDataPlacement == "" { + return PlacementBody + } + return c.UplinkDataPlacement +} + +func (c *Config) GetNormalizedSessionKey() string { + if c.SessionKey != "" { + return c.SessionKey + } + switch c.GetNormalizedSessionPlacement() { + case PlacementHeader: + return "X-Session" + case PlacementCookie, PlacementQuery: + return "x_session" + default: + return "" + } +} + +func (c *Config) GetNormalizedSeqKey() string { + if c.SeqKey != "" { + return c.SeqKey + } + switch c.GetNormalizedSeqPlacement() { + case PlacementHeader: + return "X-Seq" + case PlacementCookie, PlacementQuery: + return "x_seq" + default: + return "" + } +} + +func (c *Config) ApplyMetaToRequest(req *http.Request, sessionId string, seqStr string) { + sessionPlacement := c.GetNormalizedSessionPlacement() + seqPlacement := c.GetNormalizedSeqPlacement() + sessionKey := c.GetNormalizedSessionKey() + seqKey := c.GetNormalizedSeqKey() + + if sessionId != "" { + switch sessionPlacement { + case PlacementPath: + req.URL.Path = appendToPath(req.URL.Path, sessionId) + case PlacementQuery: + q := req.URL.Query() + q.Set(sessionKey, sessionId) + req.URL.RawQuery = q.Encode() + case PlacementHeader: + req.Header.Set(sessionKey, sessionId) + case PlacementCookie: + req.AddCookie(&http.Cookie{Name: sessionKey, Value: sessionId}) + } + } + + if seqStr != "" { + switch seqPlacement { + case PlacementPath: + req.URL.Path = appendToPath(req.URL.Path, seqStr) + case PlacementQuery: + q := req.URL.Query() + q.Set(seqKey, seqStr) + req.URL.RawQuery = q.Encode() + case PlacementHeader: + req.Header.Set(seqKey, seqStr) + case PlacementCookie: + req.AddCookie(&http.Cookie{Name: seqKey, Value: seqStr}) + } + } +} + +func (c *Config) ExtractMetaFromRequest(req *http.Request, path string) (sessionId string, seqStr string) { + sessionPlacement := c.GetNormalizedSessionPlacement() + seqPlacement := c.GetNormalizedSeqPlacement() + sessionKey := c.GetNormalizedSessionKey() + seqKey := c.GetNormalizedSeqKey() + + if sessionPlacement == PlacementPath && seqPlacement == PlacementPath { + subpath := strings.Split(req.URL.Path[len(path):], "/") + if len(subpath) > 0 { + sessionId = subpath[0] + } + if len(subpath) > 1 { + seqStr = subpath[1] + } + return sessionId, seqStr + } + + switch sessionPlacement { + case PlacementQuery: + sessionId = req.URL.Query().Get(sessionKey) + case PlacementHeader: + sessionId = req.Header.Get(sessionKey) + case PlacementCookie: + if cookie, e := req.Cookie(sessionKey); e == nil { + sessionId = cookie.Value + } + } + + switch seqPlacement { + case PlacementQuery: + seqStr = req.URL.Query().Get(seqKey) + case PlacementHeader: + seqStr = req.Header.Get(seqKey) + case PlacementCookie: + if cookie, e := req.Cookie(seqKey); e == nil { + seqStr = cookie.Value + } + } + + return sessionId, seqStr +} + func (m *XmuxConfig) GetNormalizedMaxConcurrency() RangeConfig { if m.MaxConcurrency == nil { return RangeConfig{ @@ -185,3 +302,10 @@ func init() { func (c RangeConfig) rand() int32 { return int32(crypto.RandBetween(int64(c.From), int64(c.To))) } + +func appendToPath(path, value string) string { + if strings.HasSuffix(path, "/") { + return path + value + } + return path + "/" + value +} diff --git a/transport/internet/splithttp/config.pb.go b/transport/internet/splithttp/config.pb.go index eba46806..7ec99f48 100644 --- a/transport/internet/splithttp/config.pb.go +++ b/transport/internet/splithttp/config.pb.go @@ -173,6 +173,19 @@ type Config struct { ScStreamUpServerSecs *RangeConfig `protobuf:"bytes,11,opt,name=scStreamUpServerSecs,proto3" json:"scStreamUpServerSecs,omitempty"` Xmux *XmuxConfig `protobuf:"bytes,12,opt,name=xmux,proto3" json:"xmux,omitempty"` DownloadSettings *internet.StreamConfig `protobuf:"bytes,13,opt,name=downloadSettings,proto3" json:"downloadSettings,omitempty"` + XPaddingObfsMode bool `protobuf:"varint,14,opt,name=xPaddingObfsMode,proto3" json:"xPaddingObfsMode,omitempty"` + XPaddingKey string `protobuf:"bytes,15,opt,name=xPaddingKey,proto3" json:"xPaddingKey,omitempty"` + XPaddingHeader string `protobuf:"bytes,16,opt,name=xPaddingHeader,proto3" json:"xPaddingHeader,omitempty"` + XPaddingPlacement string `protobuf:"bytes,17,opt,name=xPaddingPlacement,proto3" json:"xPaddingPlacement,omitempty"` + XPaddingMethod string `protobuf:"bytes,18,opt,name=xPaddingMethod,proto3" json:"xPaddingMethod,omitempty"` + UplinkHTTPMethod string `protobuf:"bytes,19,opt,name=uplinkHTTPMethod,proto3" json:"uplinkHTTPMethod,omitempty"` + SessionPlacement string `protobuf:"bytes,20,opt,name=sessionPlacement,proto3" json:"sessionPlacement,omitempty"` + SessionKey string `protobuf:"bytes,21,opt,name=sessionKey,proto3" json:"sessionKey,omitempty"` + SeqPlacement string `protobuf:"bytes,22,opt,name=seqPlacement,proto3" json:"seqPlacement,omitempty"` + SeqKey string `protobuf:"bytes,23,opt,name=seqKey,proto3" json:"seqKey,omitempty"` + UplinkDataPlacement string `protobuf:"bytes,24,opt,name=uplinkDataPlacement,proto3" json:"uplinkDataPlacement,omitempty"` + UplinkDataKey string `protobuf:"bytes,25,opt,name=uplinkDataKey,proto3" json:"uplinkDataKey,omitempty"` + UplinkChunkSize uint32 `protobuf:"varint,26,opt,name=uplinkChunkSize,proto3" json:"uplinkChunkSize,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -298,6 +311,97 @@ func (x *Config) GetDownloadSettings() *internet.StreamConfig { return nil } +func (x *Config) GetXPaddingObfsMode() bool { + if x != nil { + return x.XPaddingObfsMode + } + return false +} + +func (x *Config) GetXPaddingKey() string { + if x != nil { + return x.XPaddingKey + } + return "" +} + +func (x *Config) GetXPaddingHeader() string { + if x != nil { + return x.XPaddingHeader + } + return "" +} + +func (x *Config) GetXPaddingPlacement() string { + if x != nil { + return x.XPaddingPlacement + } + return "" +} + +func (x *Config) GetXPaddingMethod() string { + if x != nil { + return x.XPaddingMethod + } + return "" +} + +func (x *Config) GetUplinkHTTPMethod() string { + if x != nil { + return x.UplinkHTTPMethod + } + return "" +} + +func (x *Config) GetSessionPlacement() string { + if x != nil { + return x.SessionPlacement + } + return "" +} + +func (x *Config) GetSessionKey() string { + if x != nil { + return x.SessionKey + } + return "" +} + +func (x *Config) GetSeqPlacement() string { + if x != nil { + return x.SeqPlacement + } + return "" +} + +func (x *Config) GetSeqKey() string { + if x != nil { + return x.SeqKey + } + return "" +} + +func (x *Config) GetUplinkDataPlacement() string { + if x != nil { + return x.UplinkDataPlacement + } + return "" +} + +func (x *Config) GetUplinkDataKey() string { + if x != nil { + return x.UplinkDataKey + } + return "" +} + +func (x *Config) GetUplinkChunkSize() uint32 { + if x != nil { + return x.UplinkChunkSize + } + return 0 +} + var File_transport_internet_splithttp_config_proto protoreflect.FileDescriptor const file_transport_internet_splithttp_config_proto_rawDesc = "" + @@ -313,7 +417,8 @@ const file_transport_internet_splithttp_config_proto_rawDesc = "" + "\x0ecMaxReuseTimes\x18\x03 \x01(\v2..xray.transport.internet.splithttp.RangeConfigR\x0ecMaxReuseTimes\x12Z\n" + "\x10hMaxRequestTimes\x18\x04 \x01(\v2..xray.transport.internet.splithttp.RangeConfigR\x10hMaxRequestTimes\x12Z\n" + "\x10hMaxReusableSecs\x18\x05 \x01(\v2..xray.transport.internet.splithttp.RangeConfigR\x10hMaxReusableSecs\x12*\n" + - "\x10hKeepAlivePeriod\x18\x06 \x01(\x03R\x10hKeepAlivePeriod\"\xdc\x06\n" + + "\x10hKeepAlivePeriod\x18\x06 \x01(\x03R\x10hKeepAlivePeriod\"\xde\n" + + "\n" + "\x06Config\x12\x12\n" + "\x04host\x18\x01 \x01(\tR\x04host\x12\x12\n" + "\x04path\x18\x02 \x01(\tR\x04path\x12\x12\n" + @@ -328,7 +433,22 @@ const file_transport_internet_splithttp_config_proto_rawDesc = "" + " \x01(\x03R\x12scMaxBufferedPosts\x12b\n" + "\x14scStreamUpServerSecs\x18\v \x01(\v2..xray.transport.internet.splithttp.RangeConfigR\x14scStreamUpServerSecs\x12A\n" + "\x04xmux\x18\f \x01(\v2-.xray.transport.internet.splithttp.XmuxConfigR\x04xmux\x12Q\n" + - "\x10downloadSettings\x18\r \x01(\v2%.xray.transport.internet.StreamConfigR\x10downloadSettings\x1a:\n" + + "\x10downloadSettings\x18\r \x01(\v2%.xray.transport.internet.StreamConfigR\x10downloadSettings\x12*\n" + + "\x10xPaddingObfsMode\x18\x0e \x01(\bR\x10xPaddingObfsMode\x12 \n" + + "\vxPaddingKey\x18\x0f \x01(\tR\vxPaddingKey\x12&\n" + + "\x0exPaddingHeader\x18\x10 \x01(\tR\x0exPaddingHeader\x12,\n" + + "\x11xPaddingPlacement\x18\x11 \x01(\tR\x11xPaddingPlacement\x12&\n" + + "\x0exPaddingMethod\x18\x12 \x01(\tR\x0exPaddingMethod\x12*\n" + + "\x10uplinkHTTPMethod\x18\x13 \x01(\tR\x10uplinkHTTPMethod\x12*\n" + + "\x10sessionPlacement\x18\x14 \x01(\tR\x10sessionPlacement\x12\x1e\n" + + "\n" + + "sessionKey\x18\x15 \x01(\tR\n" + + "sessionKey\x12\"\n" + + "\fseqPlacement\x18\x16 \x01(\tR\fseqPlacement\x12\x16\n" + + "\x06seqKey\x18\x17 \x01(\tR\x06seqKey\x120\n" + + "\x13uplinkDataPlacement\x18\x18 \x01(\tR\x13uplinkDataPlacement\x12$\n" + + "\ruplinkDataKey\x18\x19 \x01(\tR\ruplinkDataKey\x12(\n" + + "\x0fuplinkChunkSize\x18\x1a \x01(\rR\x0fuplinkChunkSize\x1a:\n" + "\fHeadersEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01B\x94\x01\n" + diff --git a/transport/internet/splithttp/config.proto b/transport/internet/splithttp/config.proto index 0bf2b2ab..4f9e6876 100644 --- a/transport/internet/splithttp/config.proto +++ b/transport/internet/splithttp/config.proto @@ -36,4 +36,17 @@ message Config { RangeConfig scStreamUpServerSecs = 11; XmuxConfig xmux = 12; xray.transport.internet.StreamConfig downloadSettings = 13; + bool xPaddingObfsMode = 14; + string xPaddingKey = 15; + string xPaddingHeader = 16; + string xPaddingPlacement = 17; + string xPaddingMethod = 18; + string uplinkHTTPMethod = 19; + string sessionPlacement = 20; + string sessionKey = 21; + string seqPlacement = 22; + string seqKey = 23; + string uplinkDataPlacement = 24; + string uplinkDataKey = 25; + uint32 uplinkChunkSize = 26; } diff --git a/transport/internet/splithttp/dialer.go b/transport/internet/splithttp/dialer.go index f940a2b2..b0f7e9d9 100644 --- a/transport/internet/splithttp/dialer.go +++ b/transport/internet/splithttp/dialer.go @@ -272,8 +272,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me requestURL.Host = dest.Address.String() } - sessionIdUuid := uuid.New() - requestURL.Path = transportConfiguration.GetNormalizedPath() + sessionIdUuid.String() + requestURL.Path = transportConfiguration.GetNormalizedPath() requestURL.RawQuery = transportConfiguration.GetNormalizedQuery() httpClient, xmuxClient := getHTTPClient(ctx, dest, streamSettings) @@ -289,6 +288,12 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me } } + sessionId := "" + if mode != "stream-one" { + sessionIdUuid := uuid.New() + sessionId = sessionIdUuid.String() + } + errors.LogInfo(ctx, fmt.Sprintf("XHTTP is dialing to %s, mode %s, HTTP version %s, host %s", dest, mode, httpVersion, requestURL.Host)) requestURL2 := requestURL @@ -327,7 +332,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me if requestURL2.Host == "" { requestURL2.Host = dest2.Address.String() } - requestURL2.Path = config2.GetNormalizedPath() + sessionIdUuid.String() + requestURL2.Path = config2.GetNormalizedPath() requestURL2.RawQuery = config2.GetNormalizedQuery() httpClient2, xmuxClient2 = getHTTPClient(ctx, dest2, memory2) errors.LogInfo(ctx, fmt.Sprintf("XHTTP is downloading from %s, mode %s, HTTP version %s, host %s", dest2, "stream-down", httpVersion2, requestURL2.Host)) @@ -363,7 +368,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me if xmuxClient != nil { xmuxClient.LeftRequests.Add(-1) } - conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient.OpenStream(ctx, requestURL.String(), reader, false) + conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient.OpenStream(ctx, requestURL.String(), sessionId, reader, false) if err != nil { // browser dialer only return nil, err } @@ -372,7 +377,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me if xmuxClient2 != nil { xmuxClient2.LeftRequests.Add(-1) } - conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient2.OpenStream(ctx, requestURL2.String(), nil, false) + conn.reader, conn.remoteAddr, conn.localAddr, err = httpClient2.OpenStream(ctx, requestURL2.String(), sessionId, nil, false) if err != nil { // browser dialer only return nil, err } @@ -381,7 +386,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me if xmuxClient != nil { xmuxClient.LeftRequests.Add(-1) } - _, _, _, err = httpClient.OpenStream(ctx, requestURL.String(), reader, true) + _, _, _, err = httpClient.OpenStream(ctx, requestURL.String(), sessionId, reader, true) if err != nil { // browser dialer only return nil, err } @@ -423,8 +428,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me // this intentionally makes a shallow-copy of the struct so we // can reassign Path (potentially concurrently) url := requestURL - url.Path += "/" + strconv.FormatInt(seq, 10) - + seqStr := strconv.FormatInt(seq, 10) seq += 1 if scMinPostsIntervalMs.From > 0 { @@ -450,6 +454,8 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me err := httpClient.PostPacket( ctx, url.String(), + sessionId, + seqStr, &buf.MultiBufferContainer{MultiBuffer: chunk}, int64(chunk.Len()), ) diff --git a/transport/internet/splithttp/hub.go b/transport/internet/splithttp/hub.go index fbe6fa58..e67e0637 100644 --- a/transport/internet/splithttp/hub.go +++ b/transport/internet/splithttp/hub.go @@ -4,9 +4,10 @@ import ( "bytes" "context" gotls "crypto/tls" + "encoding/base64" + "fmt" "io" "net/http" - "net/url" "strconv" "strings" "sync" @@ -100,6 +101,24 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req } h.config.WriteResponseHeader(writer) + length := int(h.config.GetNormalizedXPaddingBytes().rand()) + config := XPaddingConfig{Length: length} + + if h.config.XPaddingObfsMode { + config.Placement = XPaddingPlacement{ + Placement: h.config.XPaddingPlacement, + Key: h.config.XPaddingKey, + Header: h.config.XPaddingHeader, + } + config.Method = PaddingMethod(h.config.XPaddingMethod) + } else { + config.Placement = XPaddingPlacement{ + Placement: PlacementHeader, + Header: "X-Padding", + } + } + + h.config.ApplyXPaddingToHeader(writer.Header(), config) /* clientVer := []int{0, 0, 0} @@ -110,29 +129,15 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req */ validRange := h.config.GetNormalizedXPaddingBytes() - paddingLength := 0 + paddingValue, paddingPlacement := h.config.ExtractXPaddingFromRequest(request, h.config.XPaddingObfsMode) - referrer := request.Header.Get("Referer") - if referrer != "" { - if referrerURL, err := url.Parse(referrer); err == nil { - // Browser dialer cannot control the host part of referrer header, so only check the query - paddingLength = len(referrerURL.Query().Get("x_padding")) - } - } else { - paddingLength = len(request.URL.Query().Get("x_padding")) - } - - if int32(paddingLength) < validRange.From || int32(paddingLength) > validRange.To { - errors.LogInfo(context.Background(), "invalid x_padding length:", int32(paddingLength)) + if !h.config.IsPaddingValid(paddingValue, validRange.From, validRange.To, PaddingMethod(h.config.XPaddingMethod)) { + errors.LogInfo(context.Background(), "invalid padding ("+paddingPlacement+") length:", int32(len(paddingValue))) writer.WriteHeader(http.StatusBadRequest) return } - sessionId := "" - subpath := strings.Split(request.URL.Path[len(h.path):], "/") - if len(subpath) > 0 { - sessionId = subpath[0] - } + sessionId, seqStr := h.config.ExtractMetaFromRequest(request, h.path) if sessionId == "" && h.config.Mode != "" && h.config.Mode != "auto" && h.config.Mode != "stream-one" && h.config.Mode != "stream-up" { errors.LogInfo(context.Background(), "stream-one mode is not allowed") @@ -178,14 +183,29 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req currentSession = h.upsertSession(sessionId) } scMaxEachPostBytes := int(h.ln.config.GetNormalizedScMaxEachPostBytes().To) + uplinkHTTPMethod := h.config.GetNormalizedUplinkHTTPMethod() + isUplinkRequest := false - if request.Method == "POST" && sessionId != "" { // stream-up, packet-up - seq := "" - if len(subpath) > 1 { - seq = subpath[1] + if uplinkHTTPMethod != "GET" && request.Method == uplinkHTTPMethod { + isUplinkRequest = true + } + + uplinkDataPlacement := h.config.GetNormalizedUplinkDataPlacement() + uplinkDataKey := h.config.UplinkDataKey + + switch uplinkDataPlacement { + case PlacementHeader: + if request.Header.Get(uplinkDataKey+"-Upstream") == "1" { + isUplinkRequest = true } + case PlacementCookie: + if c, _ := request.Cookie(uplinkDataKey + "_upstream"); c != nil && c.Value == "1" { + isUplinkRequest = true + } + } - if seq == "" { + if isUplinkRequest && sessionId != "" { // stream-up, packet-up + if seqStr == "" { if h.config.Mode != "" && h.config.Mode != "auto" && h.config.Mode != "stream-up" { errors.LogInfo(context.Background(), "stream-up mode is not allowed") writer.WriteHeader(http.StatusBadRequest) @@ -207,6 +227,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req writer.Header().Set("Cache-Control", "no-store") writer.WriteHeader(http.StatusOK) scStreamUpServerSecs := h.config.GetNormalizedScStreamUpServerSecs() + referrer := request.Header.Get("Referer") if referrer != "" && scStreamUpServerSecs.To > 0 { go func() { for { @@ -233,7 +254,62 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req return } - payload, err := io.ReadAll(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1)) + var payload []byte + + if uplinkDataPlacement != PlacementBody { + var encodedStr string + switch uplinkDataPlacement { + case PlacementHeader: + dataLenStr := request.Header.Get(uplinkDataKey + "-Length") + + if dataLenStr != "" { + dataLen, _ := strconv.Atoi(dataLenStr) + var chunks []string + i := 0 + + for { + chunk := request.Header.Get(fmt.Sprintf("%s-%d", uplinkDataKey, i)) + if chunk == "" { + break + } + chunks = append(chunks, chunk) + i++ + } + + encodedStr = strings.Join(chunks, "") + if len(encodedStr) != dataLen { + encodedStr = "" + } + } + case PlacementCookie: + var chunks []string + i := 0 + + for { + cookieName := fmt.Sprintf("%s_%d", uplinkDataKey, i) + if c, _ := request.Cookie(cookieName); c != nil { + chunks = append(chunks, c.Value) + i++ + } else { + break + } + } + + if len(chunks) > 0 { + encodedStr = strings.Join(chunks, "") + } + } + + if encodedStr != "" { + payload, err = base64.RawURLEncoding.DecodeString(encodedStr) + } else { + errors.LogInfoInner(context.Background(), err, "failed to extract data from key "+uplinkDataKey+" placed in "+uplinkDataPlacement) + writer.WriteHeader(http.StatusInternalServerError) + return + } + } else { + payload, err = io.ReadAll(io.LimitReader(request.Body, int64(scMaxEachPostBytes)+1)) + } if len(payload) > scMaxEachPostBytes { errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request size exceed it. Adjust scMaxEachPostBytes on the server to be at least as large as client.") @@ -247,7 +323,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req return } - seqInt, err := strconv.ParseUint(seq, 10, 64) + seq, err := strconv.ParseUint(seqStr, 10, 64) if err != nil { errors.LogInfoInner(context.Background(), err, "failed to upload (ParseUint)") writer.WriteHeader(http.StatusInternalServerError) @@ -256,7 +332,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req err = currentSession.uploadQueue.Push(Packet{ Payload: payload, - Seq: seqInt, + Seq: seq, }) if err != nil { diff --git a/transport/internet/splithttp/xpadding.go b/transport/internet/splithttp/xpadding.go new file mode 100644 index 00000000..ce224369 --- /dev/null +++ b/transport/internet/splithttp/xpadding.go @@ -0,0 +1,307 @@ +package splithttp + +import ( + "crypto/rand" + "math" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/http2/hpack" +) + +type PaddingMethod string + +const ( + PaddingMethodRepeatX PaddingMethod = "repeat-x" + PaddingMethodTokenish PaddingMethod = "tokenish" +) + +const charsetBase62 = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + +// Huffman encoding gives ~20% size reduction for base62 sequences +const avgHuffmanBytesPerCharBase62 = 0.8 + +const validationTolerance = 2 + +type XPaddingPlacement struct { + Placement string + Key string + Header string + RawURL string +} + +type XPaddingConfig struct { + Length int + Placement XPaddingPlacement + Method PaddingMethod +} + +func randStringFromCharset(n int, charset string) (string, bool) { + if n <= 0 || len(charset) == 0 { + return "", false + } + + m := len(charset) + limit := byte(256 - (256 % m)) + + result := make([]byte, n) + i := 0 + + buf := make([]byte, 256) + for i < n { + if _, err := rand.Read(buf); err != nil { + return "", false + } + for _, rb := range buf { + if rb >= limit { + continue + } + result[i] = charset[int(rb)%m] + i++ + if i == n { + break + } + } + } + + return string(result), true +} + +func absInt(x int) int { + if x < 0 { + return -x + } + return x +} + +func GenerateTokenishPaddingBase62(targetHuffmanBytes int) string { + n := int(math.Ceil(float64(targetHuffmanBytes) / avgHuffmanBytesPerCharBase62)) + if n < 1 { + n = 1 + } + + randBase62Str, ok := randStringFromCharset(n, charsetBase62) + if !ok { + return "" + } + + const maxIter = 150 + adjustChar := byte('X') + + // Adjust until close enough + for iter := 0; iter < maxIter; iter++ { + currentLength := int(hpack.HuffmanEncodeLength(randBase62Str)) + diff := currentLength - targetHuffmanBytes + + if absInt(diff) <= validationTolerance { + return randBase62Str + } + + if diff < 0 { + // Too small -> append padding char(s) + randBase62Str += string(adjustChar) + + // Avoid a long run of identical chars + if adjustChar == 'X' { + adjustChar = 'Z' + } else { + adjustChar = 'X' + } + } else { + // Too big -> remove from the end + if len(randBase62Str) <= 1 { + return randBase62Str + } + randBase62Str = randBase62Str[:len(randBase62Str)-1] + } + } + + return randBase62Str +} + +func GeneratePadding(method PaddingMethod, length int) string { + if length <= 0 { + return "" + } + + // https://www.rfc-editor.org/rfc/rfc7541.html#appendix-B + // h2's HPACK Header Compression feature employs a huffman encoding using a static table. + // 'X' and 'Z' are assigned an 8 bit code, so HPACK compression won't change actual padding length on the wire. + // https://www.rfc-editor.org/rfc/rfc9204.html#section-4.1.2-2 + // h3's similar QPACK feature uses the same huffman table. + + switch method { + case PaddingMethodRepeatX: + return strings.Repeat("X", length) + case PaddingMethodTokenish: + paddingValue := GenerateTokenishPaddingBase62(length) + if paddingValue == "" { + return strings.Repeat("X", length) + } + return paddingValue + default: + return strings.Repeat("X", length) + } +} + +func ApplyPaddingToCookie(req *http.Request, name, value string) { + if req == nil || name == "" || value == "" { + return + } + req.AddCookie(&http.Cookie{ + Name: name, + Value: value, + Path: "/", + }) +} + +func ApplyPaddingToQuery(u *url.URL, key, value string) { + if u == nil || key == "" || value == "" { + return + } + q := u.Query() + q.Set(key, value) + u.RawQuery = q.Encode() +} + +func (c *Config) GetNormalizedXPaddingBytes() RangeConfig { + if c.XPaddingBytes == nil || c.XPaddingBytes.To == 0 { + return RangeConfig{ + From: 100, + To: 1000, + } + } + + return *c.XPaddingBytes +} + +func (c *Config) ApplyXPaddingToHeader(h http.Header, config XPaddingConfig) { + if h == nil { + return + } + + paddingValue := GeneratePadding(config.Method, config.Length) + + switch p := config.Placement; p.Placement { + case PlacementHeader: + h.Set(p.Header, paddingValue) + case PlacementQueryInHeader: + u, err := url.Parse(p.RawURL) + if err != nil || u == nil { + return + } + u.RawQuery = p.Key + "=" + paddingValue + h.Set(p.Header, u.String()) + } +} + +func (c *Config) ApplyXPaddingToRequest(req *http.Request, config XPaddingConfig) { + if req == nil { + return + } + if req.Header == nil { + req.Header = make(http.Header) + } + + placement := config.Placement.Placement + + if placement == PlacementHeader || placement == PlacementQueryInHeader { + c.ApplyXPaddingToHeader(req.Header, config) + return + } + + paddingValue := GeneratePadding(config.Method, config.Length) + + switch placement { + case PlacementCookie: + ApplyPaddingToCookie(req, config.Placement.Key, paddingValue) + case PlacementQuery: + ApplyPaddingToQuery(req.URL, config.Placement.Key, paddingValue) + } +} + +func (c *Config) ExtractXPaddingFromRequest(req *http.Request, obfsMode bool) (string, string) { + if req == nil { + return "", "" + } + + if !obfsMode { + referrer := req.Header.Get("Referer") + + if referrer != "" { + if referrerURL, err := url.Parse(referrer); err == nil { + paddingValue := referrerURL.Query().Get("x_padding") + paddingPlacement := PlacementQueryInHeader + "=Referer, key=x_padding" + return paddingValue, paddingPlacement + } + } else { + paddingValue := req.URL.Query().Get("x_padding") + return paddingValue, PlacementQuery + ", key=x_padding" + } + } + + key := c.XPaddingKey + header := c.XPaddingHeader + + if cookie, err := req.Cookie(key); err == nil { + if cookie != nil && cookie.Value != "" { + paddingValue := cookie.Value + paddingPlacement := PlacementCookie + ", key=" + key + return paddingValue, paddingPlacement + } + } + + headerValue := req.Header.Get(header) + + if headerValue != "" { + if c.XPaddingPlacement == PlacementHeader { + paddingPlacement := PlacementHeader + "=" + header + return headerValue, paddingPlacement + } + + if parsedURL, err := url.Parse(headerValue); err == nil { + paddingPlacement := PlacementQueryInHeader + "=" + header + ", key=" + key + + return parsedURL.Query().Get(key), paddingPlacement + } + } + + queryValue := req.URL.Query().Get(key) + + if queryValue != "" { + paddingPlacement := PlacementQuery + ", key=" + key + return queryValue, paddingPlacement + } + + return "", "" +} + +func (c *Config) IsPaddingValid(paddingValue string, from, to int32, method PaddingMethod) bool { + if paddingValue == "" { + return false + } + if to <= 0 { + r := c.GetNormalizedXPaddingBytes() + from, to = r.From, r.To + } + + switch method { + case PaddingMethodRepeatX: + n := int32(len(paddingValue)) + return n >= from && n <= to + case PaddingMethodTokenish: + const tolerance = int32(validationTolerance) + + n := int32(hpack.HuffmanEncodeLength(paddingValue)) + f := from - tolerance + t := to + tolerance + if f < 0 { + f = 0 + } + return n >= f && n <= t + default: + n := int32(len(paddingValue)) + return n >= from && n <= to + } +} diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index 287134b7..e5534b47 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -289,9 +289,6 @@ func (r *RandCarrier) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509 if len(certs) == 0 { return errors.New("unexpected certs") } - if certs[0].IsCA { - slices.Reverse(certs) - } // directly return success if pinned cert is leaf // or replace RootCAs if pinned cert is CA (and can be used in VerifyPeerCertByName) @@ -384,6 +381,7 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { PinnedPeerCertSha256: c.PinnedPeerCertSha256, } config := &tls.Config{ + InsecureSkipVerify: c.AllowInsecure, Rand: randCarrier, ClientSessionCache: globalSessionCache, RootCAs: root, @@ -557,14 +555,19 @@ const ( ) func verifyChain(certs []*x509.Certificate, pinnedPeerCertSha256 [][]byte) (verifyResult, *x509.Certificate) { + leafHash := GenerateCertHash(certs[0]) + for _, c := range pinnedPeerCertSha256 { + if hmac.Equal(leafHash, c) { + return foundLeaf, nil + } + } + certs = certs[1:] // skip leaf for _, cert := range certs { certHash := GenerateCertHash(cert) for _, c := range pinnedPeerCertSha256 { if hmac.Equal(certHash, c) { if cert.IsCA { return foundCA, cert - } else { - return foundLeaf, cert } } } diff --git a/transport/internet/tls/config.pb.go b/transport/internet/tls/config.pb.go index 25376977..fa86b0b8 100644 --- a/transport/internet/tls/config.pb.go +++ b/transport/internet/tls/config.pb.go @@ -177,7 +177,8 @@ func (x *Certificate) GetBuildChain() bool { } type Config struct { - state protoimpl.MessageState `protogen:"open.v1"` + state protoimpl.MessageState `protogen:"open.v1"` + AllowInsecure bool `protobuf:"varint,1,opt,name=allow_insecure,json=allowInsecure,proto3" json:"allow_insecure,omitempty"` // List of certificates to be served on server. Certificate []*Certificate `protobuf:"bytes,2,rep,name=certificate,proto3" json:"certificate,omitempty"` // Override server name. @@ -241,6 +242,13 @@ func (*Config) Descriptor() ([]byte, []int) { return file_transport_internet_tls_config_proto_rawDescGZIP(), []int{1} } +func (x *Config) GetAllowInsecure() bool { + if x != nil { + return x.AllowInsecure + } + return false +} + func (x *Config) GetCertificate() []*Certificate { if x != nil { return x.Certificate @@ -385,8 +393,9 @@ const file_transport_internet_tls_config_proto_rawDesc = "" + "\x05Usage\x12\x10\n" + "\fENCIPHERMENT\x10\x00\x12\x14\n" + "\x10AUTHORITY_VERIFY\x10\x01\x12\x13\n" + - "\x0fAUTHORITY_ISSUE\x10\x02\"\xce\x06\n" + - "\x06Config\x12J\n" + + "\x0fAUTHORITY_ISSUE\x10\x02\"\xf5\x06\n" + + "\x06Config\x12%\n" + + "\x0eallow_insecure\x18\x01 \x01(\bR\rallowInsecure\x12J\n" + "\vcertificate\x18\x02 \x03(\v2(.xray.transport.internet.tls.CertificateR\vcertificate\x12\x1f\n" + "\vserver_name\x18\x03 \x01(\tR\n" + "serverName\x12#\n" + diff --git a/transport/internet/tls/config.proto b/transport/internet/tls/config.proto index 52b54b1a..34c918aa 100644 --- a/transport/internet/tls/config.proto +++ b/transport/internet/tls/config.proto @@ -38,6 +38,8 @@ message Certificate { } message Config { + bool allow_insecure = 1; + // List of certificates to be served on server. repeated Certificate certificate = 2; diff --git a/transport/internet/tls/ech.go b/transport/internet/tls/ech.go index f9c25c2a..220685d5 100644 --- a/transport/internet/tls/ech.go +++ b/transport/internet/tls/ech.go @@ -257,7 +257,9 @@ func dnsQuery(server string, domain string, sockopt *internet.SocketConfig) ([]b } req.Header.Set("Accept", "application/dns-message") req.Header.Set("Content-Type", "application/dns-message") - req.Header.Set("X-Padding", strings.Repeat("X", int(crypto.RandBetween(100, 1000)))) + req.Header.Set("User-Agent", utils.ChromeUA) + req.Header.Set("X-Padding", utils.H2Base62Pad(crypto.RandBetween(100, 1000))) + resp, err := client.Do(req) if err != nil { return nil, 0, err diff --git a/transport/internet/tls/pin.go b/transport/internet/tls/pin.go index 060f5d73..54029572 100644 --- a/transport/internet/tls/pin.go +++ b/transport/internet/tls/pin.go @@ -4,28 +4,8 @@ import ( "crypto/sha256" "crypto/x509" "encoding/hex" - "encoding/pem" ) -func CalculatePEMLeafCertSHA256Hash(certContent []byte) (string, error) { - var leafCert *x509.Certificate - for { - var err error - block, remain := pem.Decode(certContent) - if block == nil { - break - } - leafCert, err = x509.ParseCertificate(block.Bytes) - if err != nil { - return "", err - } - certContent = remain - } - certHash := GenerateCertHash(leafCert) - certHashHex := hex.EncodeToString(certHash) - return certHashHex, nil -} - // []byte must be ASN.1 DER content func GenerateCertHash[T *x509.Certificate | []byte](cert T) []byte { var out [32]byte @@ -37,3 +17,14 @@ func GenerateCertHash[T *x509.Certificate | []byte](cert T) []byte { } return out[:] } + +func GenerateCertHashHex[T *x509.Certificate | []byte](cert T) string { + var out [32]byte + switch v := any(cert).(type) { + case *x509.Certificate: + out = sha256.Sum256(v.Raw) + case []byte: + out = sha256.Sum256(v) + } + return hex.EncodeToString(out[:]) +} diff --git a/transport/internet/tls/tls.go b/transport/internet/tls/tls.go index e46f9687..e183c82f 100644 --- a/transport/internet/tls/tls.go +++ b/transport/internet/tls/tls.go @@ -126,6 +126,10 @@ func UClient(c net.Conn, config *tls.Config, fingerprint *utls.ClientHelloID) ne return &UConn{UConn: utlsConn} } +func GeneraticUClient(c net.Conn, config *tls.Config) *utls.UConn { + return utls.UClient(c, copyConfig(config), utls.HelloChrome_Auto) +} + func copyConfig(c *tls.Config) *utls.Config { return &utls.Config{ Rand: c.Rand, diff --git a/transport/internet/udp/dialer.go b/transport/internet/udp/dialer.go index ff9052a9..e36c3c55 100644 --- a/transport/internet/udp/dialer.go +++ b/transport/internet/udp/dialer.go @@ -4,6 +4,7 @@ import ( "context" "github.com/amnezia-vpn/amnezia-xray-core/common" + "github.com/amnezia-vpn/amnezia-xray-core/common/errors" "github.com/amnezia-vpn/amnezia-xray-core/common/net" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet/stat" @@ -20,7 +21,24 @@ func init() { if err != nil { return nil, err } + + if streamSettings != nil && streamSettings.UdpmaskManager != nil { + wrapper, ok := conn.(*internet.PacketConnWrapper) + if !ok { + conn.Close() + return nil, errors.New("conn is not PacketConnWrapper") + } + + raw := wrapper.Conn + + wrapper.Conn, err = streamSettings.UdpmaskManager.WrapPacketConnClient(raw) + if err != nil { + raw.Close() + return nil, errors.New("mask err").Base(err) + } + } + // TODO: handle dialer options - return stat.Connection(conn), nil + return conn, nil })) } diff --git a/transport/internet/udp/hub.go b/transport/internet/udp/hub.go index 6c0e0614..868eeed4 100644 --- a/transport/internet/udp/hub.go +++ b/transport/internet/udp/hub.go @@ -25,7 +25,8 @@ func HubReceiveOriginalDestination(r bool) HubOption { } type Hub struct { - conn *net.UDPConn + conn net.PacketConn + udpConn *net.UDPConn cache chan *udp.Packet capacity int recvOrigDest bool @@ -56,15 +57,27 @@ func ListenUDP(ctx context.Context, address net.Address, port net.Port, streamSe hub.recvOrigDest = true } - udpConn, err := internet.ListenSystemPacket(ctx, &net.UDPAddr{ + var err error + hub.conn, err = internet.ListenSystemPacket(ctx, &net.UDPAddr{ IP: address.IP(), Port: int(port), }, sockopt) if err != nil { return nil, err } + + raw := hub.conn + + if streamSettings.UdpmaskManager != nil { + hub.conn, err = streamSettings.UdpmaskManager.WrapPacketConnServer(raw) + if err != nil { + raw.Close() + return nil, errors.New("mask err").Base(err) + } + } + errors.LogInfo(ctx, "listening UDP on ", address, ":", port) - hub.conn = udpConn.(*net.UDPConn) + hub.udpConn, _ = hub.conn.(*net.UDPConn) hub.cache = make(chan *udp.Packet, hub.capacity) go hub.start() @@ -78,7 +91,7 @@ func (h *Hub) Close() error { } func (h *Hub) WriteTo(payload []byte, dest net.Destination) (int, error) { - return h.conn.WriteToUDP(payload, &net.UDPAddr{ + return h.conn.WriteTo(payload, &net.UDPAddr{ IP: dest.Address.IP(), Port: int(dest.Port), }) @@ -93,10 +106,21 @@ func (h *Hub) start() { for { buffer := buf.New() var noob int - var addr *net.UDPAddr + var udpAddr *net.UDPAddr rawBytes := buffer.Extend(buf.Size) - n, noob, _, addr, err := ReadUDPMsg(h.conn, rawBytes, oobBytes) + var n int + var err error + if h.udpConn != nil { + n, noob, _, udpAddr, err = ReadUDPMsg(h.udpConn, rawBytes, oobBytes) + } else { + var addr net.Addr + n, addr, err = h.conn.ReadFrom(rawBytes) + if err == nil { + udpAddr = addr.(*net.UDPAddr) + } + } + if err != nil { errors.LogInfoInner(context.Background(), err, "failed to read UDP msg") buffer.Release() @@ -111,7 +135,7 @@ func (h *Hub) start() { payload := &udp.Packet{ Payload: buffer, - Source: net.UDPDestination(net.IPAddress(addr.IP), net.Port(addr.Port)), + Source: net.UDPDestination(net.IPAddress(udpAddr.IP), net.Port(udpAddr.Port)), } if h.recvOrigDest && noob > 0 { payload.Target = RetrieveOriginalDest(oobBytes[:noob]) diff --git a/transport/internet/websocket/config.go b/transport/internet/websocket/config.go index 67e8a08c..862c2301 100644 --- a/transport/internet/websocket/config.go +++ b/transport/internet/websocket/config.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/amnezia-vpn/amnezia-xray-core/common" + "github.com/amnezia-vpn/amnezia-xray-core/common/utils" "github.com/amnezia-vpn/amnezia-xray-core/transport/internet" ) @@ -23,6 +24,9 @@ func (c *Config) GetRequestHeader() http.Header { for k, v := range c.Header { header.Add(k, v) } + if header.Get("User-Agent") == "" { + header.Set("User-Agent", utils.ChromeUA) + } return header }