mirror of
https://github.com/amnezia-vpn/amneziawg-windows-client.git
synced 2026-05-17 08:15:44 +03:00
542 lines
12 KiB
Go
542 lines
12 KiB
Go
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
 */
|
||
|
||
package manager
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/gob"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"os"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"golang.org/x/sys/windows"
|
||
"golang.org/x/sys/windows/svc"
|
||
|
||
"github.com/amnezia-vpn/amneziawg-windows-client/updater"
|
||
"github.com/amnezia-vpn/amneziawg-windows/conf"
|
||
"github.com/amnezia-vpn/amneziawg-windows/services"
|
||
)
|
||
|
||
var (
|
||
managerServices = make(map[*ManagerService]bool)
|
||
managerServicesLock sync.RWMutex
|
||
haveQuit uint32
|
||
quitManagersChan = make(chan struct{}, 1)
|
||
)
|
||
|
||
type ManagerService struct {
|
||
events *os.File
|
||
eventLock sync.Mutex
|
||
elevatedToken windows.Token
|
||
}
|
||
|
||
func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) {
|
||
conf, err := conf.LoadFromName(tunnelName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if s.elevatedToken == 0 {
|
||
conf.Redact()
|
||
}
|
||
return conf, nil
|
||
}
|
||
|
||
func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) {
|
||
storedConfig, err := conf.LoadFromName(tunnelName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
pipe, err := connectTunnelServicePipe(tunnelName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
pipe.SetDeadline(time.Now().Add(time.Second * 2))
|
||
_, err = pipe.Write([]byte("get=1\n\n"))
|
||
if err == windows.ERROR_NO_DATA {
|
||
log.Println("IPC pipe closed unexpectedly, so reopening")
|
||
pipe.Unlock()
|
||
disconnectTunnelServicePipe(tunnelName)
|
||
pipe, err = connectTunnelServicePipe(tunnelName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
pipe.SetDeadline(time.Now().Add(time.Second * 2))
|
||
_, err = pipe.Write([]byte("get=1\n\n"))
|
||
}
|
||
if err != nil {
|
||
pipe.Unlock()
|
||
disconnectTunnelServicePipe(tunnelName)
|
||
return nil, err
|
||
}
|
||
conf, err := conf.FromUAPI(pipe, storedConfig)
|
||
pipe.Unlock()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if s.elevatedToken == 0 {
|
||
conf.Redact()
|
||
}
|
||
return conf, nil
|
||
}
|
||
|
||
func (s *ManagerService) Start(tunnelName string) error {
|
||
c, err := conf.LoadFromName(tunnelName)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// Figure out which tunnels have intersecting addresses/routes and stop those.
|
||
trackedTunnelsLock.Lock()
|
||
tt := make([]string, 0, len(trackedTunnels))
|
||
var inTransition string
|
||
for t, state := range trackedTunnels {
|
||
c2, err := conf.LoadFromName(t)
|
||
if err != nil || !c.IntersectsWith(c2) {
|
||
// If we can't get the config, assume it doesn't intersect.
|
||
continue
|
||
}
|
||
tt = append(tt, t)
|
||
if len(t) > 0 && (state == TunnelStarting || state == TunnelUnknown) {
|
||
inTransition = t
|
||
break
|
||
}
|
||
}
|
||
trackedTunnelsLock.Unlock()
|
||
if len(inTransition) != 0 {
|
||
return fmt.Errorf("Please allow the tunnel ‘%s’ to finish activating", inTransition)
|
||
}
|
||
|
||
// Stop those intersecting tunnels asynchronously.
|
||
go func() {
|
||
for _, t := range tt {
|
||
s.Stop(t)
|
||
}
|
||
for _, t := range tt {
|
||
state, err := s.State(t)
|
||
if err == nil && (state == TunnelStarted || state == TunnelStarting) {
|
||
log.Printf("[%s] Trying again to stop zombie tunnel", t)
|
||
s.Stop(t)
|
||
time.Sleep(time.Millisecond * 100)
|
||
}
|
||
}
|
||
}()
|
||
// After the stop process has begun, but before it's finished, we install the new one.
|
||
path, err := c.Path()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return InstallTunnel(path)
|
||
}
|
||
|
||
func (s *ManagerService) Stop(tunnelName string) error {
|
||
err := UninstallTunnel(tunnelName)
|
||
if err == windows.ERROR_SERVICE_DOES_NOT_EXIST {
|
||
_, notExistsError := conf.LoadFromName(tunnelName)
|
||
if notExistsError == nil {
|
||
return nil
|
||
}
|
||
}
|
||
return err
|
||
}
|
||
|
||
func (s *ManagerService) WaitForStop(tunnelName string) error {
|
||
serviceName, err := services.ServiceNameOfTunnel(tunnelName)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
m, err := serviceManager()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
for {
|
||
service, err := m.OpenService(serviceName)
|
||
if err == nil || err == windows.ERROR_SERVICE_MARKED_FOR_DELETE {
|
||
service.Close()
|
||
time.Sleep(time.Second / 3)
|
||
} else {
|
||
return nil
|
||
}
|
||
}
|
||
}
|
||
|
||
func (s *ManagerService) Delete(tunnelName string) error {
|
||
if s.elevatedToken == 0 {
|
||
return windows.ERROR_ACCESS_DENIED
|
||
}
|
||
err := s.Stop(tunnelName)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return conf.DeleteName(tunnelName)
|
||
}
|
||
|
||
func (s *ManagerService) State(tunnelName string) (TunnelState, error) {
|
||
serviceName, err := services.ServiceNameOfTunnel(tunnelName)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
m, err := serviceManager()
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
service, err := m.OpenService(serviceName)
|
||
if err != nil {
|
||
return TunnelStopped, nil
|
||
}
|
||
defer service.Close()
|
||
status, err := service.Query()
|
||
if err != nil {
|
||
return TunnelUnknown, nil
|
||
}
|
||
switch status.State {
|
||
case svc.Stopped:
|
||
return TunnelStopped, nil
|
||
case svc.StopPending:
|
||
return TunnelStopping, nil
|
||
case svc.Running:
|
||
return TunnelStarted, nil
|
||
case svc.StartPending:
|
||
return TunnelStarting, nil
|
||
default:
|
||
return TunnelUnknown, nil
|
||
}
|
||
}
|
||
|
||
func (s *ManagerService) GlobalState() TunnelState {
|
||
return trackedTunnelsGlobalState()
|
||
}
|
||
|
||
func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) {
|
||
if s.elevatedToken == 0 {
|
||
return nil, windows.ERROR_ACCESS_DENIED
|
||
}
|
||
err := tunnelConfig.Save(true)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &Tunnel{tunnelConfig.Name}, nil
|
||
// TODO: handle already existing situation
|
||
// TODO: handle already running and existing situation
|
||
}
|
||
|
||
func (s *ManagerService) Tunnels() ([]Tunnel, error) {
|
||
names, err := conf.ListConfigNames()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
tunnels := make([]Tunnel, len(names))
|
||
for i := 0; i < len(tunnels); i++ {
|
||
tunnels[i].Name = names[i]
|
||
}
|
||
return tunnels, nil
|
||
// TODO: account for running ones that aren't in the configuration store somehow
|
||
}
|
||
|
||
func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) {
|
||
if s.elevatedToken == 0 {
|
||
return false, windows.ERROR_ACCESS_DENIED
|
||
}
|
||
if !atomic.CompareAndSwapUint32(&haveQuit, 0, 1) {
|
||
return true, nil
|
||
}
|
||
|
||
// Work around potential race condition of delivering messages to the wrong process by removing from notifications.
|
||
managerServicesLock.Lock()
|
||
s.eventLock.Lock()
|
||
s.events = nil
|
||
s.eventLock.Unlock()
|
||
delete(managerServices, s)
|
||
managerServicesLock.Unlock()
|
||
|
||
if stopTunnelsOnQuit {
|
||
names, err := conf.ListConfigNames()
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
for _, name := range names {
|
||
UninstallTunnel(name)
|
||
}
|
||
}
|
||
|
||
quitManagersChan <- struct{}{}
|
||
return false, nil
|
||
}
|
||
|
||
func (s *ManagerService) UpdateState() UpdateState {
|
||
return updateState
|
||
}
|
||
|
||
func (s *ManagerService) Update() {
|
||
if s.elevatedToken == 0 {
|
||
return
|
||
}
|
||
progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken))
|
||
go func() {
|
||
for {
|
||
dp := <-progress
|
||
IPCServerNotifyUpdateProgress(dp)
|
||
if dp.Complete || dp.Error != nil {
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) {
|
||
decoder := gob.NewDecoder(reader)
|
||
encoder := gob.NewEncoder(writer)
|
||
for {
|
||
var methodType MethodType
|
||
err := decoder.Decode(&methodType)
|
||
if err != nil {
|
||
return
|
||
}
|
||
switch methodType {
|
||
case StoredConfigMethodType:
|
||
var tunnelName string
|
||
err := decoder.Decode(&tunnelName)
|
||
if err != nil {
|
||
return
|
||
}
|
||
config, retErr := s.StoredConfig(tunnelName)
|
||
if config == nil {
|
||
config = &conf.Config{}
|
||
}
|
||
err = encoder.Encode(*config)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = encoder.Encode(errToString(retErr))
|
||
if err != nil {
|
||
return
|
||
}
|
||
case RuntimeConfigMethodType:
|
||
var tunnelName string
|
||
err := decoder.Decode(&tunnelName)
|
||
if err != nil {
|
||
return
|
||
}
|
||
config, retErr := s.RuntimeConfig(tunnelName)
|
||
if config == nil {
|
||
config = &conf.Config{}
|
||
}
|
||
err = encoder.Encode(*config)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = encoder.Encode(errToString(retErr))
|
||
if err != nil {
|
||
return
|
||
}
|
||
case StartMethodType:
|
||
var tunnelName string
|
||
err := decoder.Decode(&tunnelName)
|
||
if err != nil {
|
||
return
|
||
}
|
||
retErr := s.Start(tunnelName)
|
||
err = encoder.Encode(errToString(retErr))
|
||
if err != nil {
|
||
return
|
||
}
|
||
case StopMethodType:
|
||
var tunnelName string
|
||
err := decoder.Decode(&tunnelName)
|
||
if err != nil {
|
||
return
|
||
}
|
||
retErr := s.Stop(tunnelName)
|
||
err = encoder.Encode(errToString(retErr))
|
||
if err != nil {
|
||
return
|
||
}
|
||
case WaitForStopMethodType:
|
||
var tunnelName string
|
||
err := decoder.Decode(&tunnelName)
|
||
if err != nil {
|
||
return
|
||
}
|
||
retErr := s.WaitForStop(tunnelName)
|
||
err = encoder.Encode(errToString(retErr))
|
||
if err != nil {
|
||
return
|
||
}
|
||
case DeleteMethodType:
|
||
var tunnelName string
|
||
err := decoder.Decode(&tunnelName)
|
||
if err != nil {
|
||
return
|
||
}
|
||
retErr := s.Delete(tunnelName)
|
||
err = encoder.Encode(errToString(retErr))
|
||
if err != nil {
|
||
return
|
||
}
|
||
case StateMethodType:
|
||
var tunnelName string
|
||
err := decoder.Decode(&tunnelName)
|
||
if err != nil {
|
||
return
|
||
}
|
||
state, retErr := s.State(tunnelName)
|
||
err = encoder.Encode(state)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = encoder.Encode(errToString(retErr))
|
||
if err != nil {
|
||
return
|
||
}
|
||
case GlobalStateMethodType:
|
||
state := s.GlobalState()
|
||
err = encoder.Encode(state)
|
||
if err != nil {
|
||
return
|
||
}
|
||
case CreateMethodType:
|
||
var config conf.Config
|
||
err := decoder.Decode(&config)
|
||
if err != nil {
|
||
return
|
||
}
|
||
tunnel, retErr := s.Create(&config)
|
||
if tunnel == nil {
|
||
tunnel = &Tunnel{}
|
||
}
|
||
err = encoder.Encode(tunnel)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = encoder.Encode(errToString(retErr))
|
||
if err != nil {
|
||
return
|
||
}
|
||
case TunnelsMethodType:
|
||
tunnels, retErr := s.Tunnels()
|
||
err = encoder.Encode(tunnels)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = encoder.Encode(errToString(retErr))
|
||
if err != nil {
|
||
return
|
||
}
|
||
case QuitMethodType:
|
||
var stopTunnelsOnQuit bool
|
||
err := decoder.Decode(&stopTunnelsOnQuit)
|
||
if err != nil {
|
||
return
|
||
}
|
||
alreadyQuit, retErr := s.Quit(stopTunnelsOnQuit)
|
||
err = encoder.Encode(alreadyQuit)
|
||
if err != nil {
|
||
return
|
||
}
|
||
err = encoder.Encode(errToString(retErr))
|
||
if err != nil {
|
||
return
|
||
}
|
||
case UpdateStateMethodType:
|
||
updateState := s.UpdateState()
|
||
err = encoder.Encode(updateState)
|
||
if err != nil {
|
||
return
|
||
}
|
||
case UpdateMethodType:
|
||
s.Update()
|
||
default:
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) {
|
||
service := &ManagerService{
|
||
events: events,
|
||
elevatedToken: elevatedToken,
|
||
}
|
||
|
||
go func() {
|
||
managerServicesLock.Lock()
|
||
managerServices[service] = true
|
||
managerServicesLock.Unlock()
|
||
service.ServeConn(reader, writer)
|
||
managerServicesLock.Lock()
|
||
service.eventLock.Lock()
|
||
service.events = nil
|
||
service.eventLock.Unlock()
|
||
delete(managerServices, service)
|
||
managerServicesLock.Unlock()
|
||
}()
|
||
}
|
||
|
||
func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) {
|
||
if len(managerServices) == 0 {
|
||
return
|
||
}
|
||
|
||
var buf bytes.Buffer
|
||
encoder := gob.NewEncoder(&buf)
|
||
err := encoder.Encode(notificationType)
|
||
if err != nil {
|
||
return
|
||
}
|
||
for _, iface := range ifaces {
|
||
err = encoder.Encode(iface)
|
||
if err != nil {
|
||
return
|
||
}
|
||
}
|
||
|
||
managerServicesLock.RLock()
|
||
for m := range managerServices {
|
||
if m.elevatedToken == 0 && adminOnly {
|
||
continue
|
||
}
|
||
go func(m *ManagerService) {
|
||
m.eventLock.Lock()
|
||
defer m.eventLock.Unlock()
|
||
if m.events != nil {
|
||
m.events.SetWriteDeadline(time.Now().Add(time.Second))
|
||
m.events.Write(buf.Bytes())
|
||
}
|
||
}(m)
|
||
}
|
||
managerServicesLock.RUnlock()
|
||
}
|
||
|
||
// errToString renders an error for transmission over IPC: the error's message,
// or the empty string when err is nil.
func errToString(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
|
||
|
||
func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) {
|
||
notifyAll(TunnelChangeNotificationType, false, name, state, trackedTunnelsGlobalState(), errToString(err))
|
||
}
|
||
|
||
func IPCServerNotifyTunnelsChange() {
|
||
notifyAll(TunnelsChangeNotificationType, false)
|
||
}
|
||
|
||
func IPCServerNotifyUpdateFound(state UpdateState) {
|
||
notifyAll(UpdateFoundNotificationType, false, state)
|
||
}
|
||
|
||
func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) {
|
||
notifyAll(UpdateProgressNotificationType, true, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errToString(dp.Error), dp.Complete)
|
||
}
|
||
|
||
func IPCServerNotifyManagerStopping() {
|
||
notifyAll(ManagerStoppingNotificationType, false)
|
||
time.Sleep(time.Millisecond * 200)
|
||
}
|