Files
amneziawg-windows-client/manager/ipc_server.go
2025-03-10 14:16:49 +07:00

542 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package manager
import (
"bytes"
"encoding/gob"
"fmt"
"io"
"log"
"os"
"sync"
"sync/atomic"
"time"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"github.com/amnezia-vpn/amneziawg-windows-client/updater"
"github.com/amnezia-vpn/amneziawg-windows/conf"
"github.com/amnezia-vpn/amneziawg-windows/services"
)
var (
// managerServices is the set of registered IPC services. Entries are added
// in IPCServerListen and removed there and in Quit, under managerServicesLock.
managerServices = make(map[*ManagerService]bool)
// managerServicesLock is taken around accesses to managerServices.
managerServicesLock sync.RWMutex
// haveQuit is switched from 0 to 1 by an atomic compare-and-swap in Quit, so
// only the first privileged Quit call proceeds.
haveQuit uint32
// quitManagersChan receives one value at the end of a successful Quit. The
// receiving side is not in this file.
quitManagersChan = make(chan struct{}, 1)
)
// ManagerService is the server-side state for one IPC client. IPCServerListen
// creates one instance per call and serves requests on it through ServeConn.
type ManagerService struct {
// events is the file that notifyAll writes notifications to. It is set to
// nil when the service is unregistered.
events *os.File
// eventLock is held while events is read, written to, or cleared.
eventLock sync.Mutex
// elevatedToken is zero for unprivileged clients. Privileged methods refuse
// to act, and config getters redact their result, when it is zero.
elevatedToken windows.Token
}
// StoredConfig loads the saved configuration named tunnelName. When the
// client has no elevated token, Redact is called on the configuration before
// it is returned.
func (s *ManagerService) StoredConfig(tunnelName string) (*conf.Config, error) {
	stored, err := conf.LoadFromName(tunnelName)
	if err != nil {
		return nil, err
	}
	if unprivileged := s.elevatedToken == 0; unprivileged {
		stored.Redact()
	}
	return stored, nil
}
// RuntimeConfig asks the tunnel service for its live configuration by sending
// a "get=1" request over the tunnel's pipe and parsing the reply with
// conf.FromUAPI, using the stored configuration as the base. If the first
// write fails with ERROR_NO_DATA the pipe is dropped and reopened once. When
// the client has no elevated token, the result is redacted.
func (s *ManagerService) RuntimeConfig(tunnelName string) (*conf.Config, error) {
	storedConfig, err := conf.LoadFromName(tunnelName)
	if err != nil {
		return nil, err
	}

	const ioTimeout = 2 * time.Second
	request := []byte("get=1\n\n")

	pipe, err := connectTunnelServicePipe(tunnelName)
	if err != nil {
		return nil, err
	}

	// sendGet arms the deadline and writes the request on whatever pipe is
	// current; the closure sees the reassignment done in the retry below.
	sendGet := func() error {
		pipe.SetDeadline(time.Now().Add(ioTimeout))
		_, writeErr := pipe.Write(request)
		return writeErr
	}

	err = sendGet()
	if err == windows.ERROR_NO_DATA {
		log.Println("IPC pipe closed unexpectedly, so reopening")
		pipe.Unlock()
		disconnectTunnelServicePipe(tunnelName)
		pipe, err = connectTunnelServicePipe(tunnelName)
		if err != nil {
			return nil, err
		}
		err = sendGet()
	}
	if err != nil {
		pipe.Unlock()
		disconnectTunnelServicePipe(tunnelName)
		return nil, err
	}

	// The pipe is unlocked as soon as the reply has been consumed, whether or
	// not parsing succeeded.
	liveConfig, err := conf.FromUAPI(pipe, storedConfig)
	pipe.Unlock()
	if err != nil {
		return nil, err
	}
	if s.elevatedToken == 0 {
		liveConfig.Redact()
	}
	return liveConfig, nil
}
// Start installs the tunnel named tunnelName. Tracked tunnels whose
// configuration intersects with the new one are stopped in a background
// goroutine. If one of them is still starting or in an unknown state, Start
// refuses and returns an error naming it.
func (s *ManagerService) Start(tunnelName string) error {
	newConfig, err := conf.LoadFromName(tunnelName)
	if err != nil {
		return err
	}

	// Collect the tracked tunnels that overlap with the new configuration.
	trackedTunnelsLock.Lock()
	conflicting := make([]string, 0, len(trackedTunnels))
	var inTransition string
	for name, state := range trackedTunnels {
		other, loadErr := conf.LoadFromName(name)
		// A tunnel whose config cannot be loaded is assumed not to intersect.
		if loadErr != nil || !newConfig.IntersectsWith(other) {
			continue
		}
		conflicting = append(conflicting, name)
		transitional := state == TunnelStarting || state == TunnelUnknown
		if name != "" && transitional {
			inTransition = name
			break
		}
	}
	trackedTunnelsLock.Unlock()

	if inTransition != "" {
		return fmt.Errorf("Please allow the tunnel %s to finish activating", inTransition)
	}

	// Stop the conflicting tunnels without blocking the caller, then make a
	// second pass over any that still report as started or starting.
	go func() {
		for _, name := range conflicting {
			s.Stop(name)
		}
		for _, name := range conflicting {
			state, stateErr := s.State(name)
			if stateErr != nil {
				continue
			}
			if state == TunnelStarted || state == TunnelStarting {
				log.Printf("[%s] Trying again to stop zombie tunnel", name)
				s.Stop(name)
				time.Sleep(100 * time.Millisecond)
			}
		}
	}()

	// The new tunnel is installed while the stops above may still be running.
	path, err := newConfig.Path()
	if err != nil {
		return err
	}
	return InstallTunnel(path)
}
// Stop uninstalls the tunnel named tunnelName. An
// ERROR_SERVICE_DOES_NOT_EXIST result is reported as success when the
// tunnel's configuration still loads; every other result is returned as is.
func (s *ManagerService) Stop(tunnelName string) error {
	err := UninstallTunnel(tunnelName)
	if err != windows.ERROR_SERVICE_DOES_NOT_EXIST {
		return err
	}
	if _, loadErr := conf.LoadFromName(tunnelName); loadErr == nil {
		return nil
	}
	return err
}
// WaitForStop blocks until the service backing tunnelName can no longer be
// opened. It polls every third of a second for as long as OpenService either
// succeeds or fails with ERROR_SERVICE_MARKED_FOR_DELETE, and returns nil on
// any other OpenService error. An error is returned only when the service
// name or the service manager cannot be obtained. There is no timeout.
func (s *ManagerService) WaitForStop(tunnelName string) error {
	serviceName, err := services.ServiceNameOfTunnel(tunnelName)
	if err != nil {
		return err
	}
	m, err := serviceManager()
	if err != nil {
		return err
	}
	for {
		service, err := m.OpenService(serviceName)
		switch {
		case err == nil:
			// Only a successfully opened service has a handle to release.
			service.Close()
		case err == windows.ERROR_SERVICE_MARKED_FOR_DELETE:
			// Still going away. OpenService failed, so there is no valid
			// handle here and calling Close on it must be avoided.
		default:
			return nil
		}
		time.Sleep(time.Second / 3)
	}
}
// Delete stops the tunnel named tunnelName and then removes its stored
// configuration. It returns ERROR_ACCESS_DENIED when the client has no
// elevated token, and the Stop error if stopping fails.
func (s *ManagerService) Delete(tunnelName string) error {
	if s.elevatedToken == 0 {
		return windows.ERROR_ACCESS_DENIED
	}
	if stopErr := s.Stop(tunnelName); stopErr != nil {
		return stopErr
	}
	return conf.DeleteName(tunnelName)
}
// State reports the state of the service backing tunnelName. A service that
// cannot be opened is reported as TunnelStopped, and one that cannot be
// queried as TunnelUnknown, both with a nil error. An error is returned only
// when the service name or the service manager cannot be obtained.
func (s *ManagerService) State(tunnelName string) (TunnelState, error) {
	serviceName, err := services.ServiceNameOfTunnel(tunnelName)
	if err != nil {
		return 0, err
	}
	m, err := serviceManager()
	if err != nil {
		return 0, err
	}

	service, openErr := m.OpenService(serviceName)
	if openErr != nil {
		return TunnelStopped, nil
	}
	defer service.Close()

	status, queryErr := service.Query()
	if queryErr != nil {
		return TunnelUnknown, nil
	}

	// Translate the service state; anything not listed maps to TunnelUnknown.
	var tunnelState TunnelState
	switch status.State {
	case svc.Running:
		tunnelState = TunnelStarted
	case svc.StartPending:
		tunnelState = TunnelStarting
	case svc.Stopped:
		tunnelState = TunnelStopped
	case svc.StopPending:
		tunnelState = TunnelStopping
	default:
		tunnelState = TunnelUnknown
	}
	return tunnelState, nil
}
// GlobalState returns the value of trackedTunnelsGlobalState.
func (s *ManagerService) GlobalState() TunnelState {
	state := trackedTunnelsGlobalState()
	return state
}
// Create saves tunnelConfig and returns a Tunnel carrying its name. It
// returns ERROR_ACCESS_DENIED when the client has no elevated token.
func (s *ManagerService) Create(tunnelConfig *conf.Config) (*Tunnel, error) {
	if s.elevatedToken == 0 {
		return nil, windows.ERROR_ACCESS_DENIED
	}
	// NOTE(review): the meaning of the true argument to Save is not visible
	// here (presumably "overwrite") — confirm against conf.Config.Save.
	if saveErr := tunnelConfig.Save(true); saveErr != nil {
		return nil, saveErr
	}
	created := &Tunnel{tunnelConfig.Name}
	return created, nil
	// TODO: handle already existing situation
	// TODO: handle already running and existing situation
}
// Tunnels returns one Tunnel per name reported by conf.ListConfigNames, in
// the same order.
func (s *ManagerService) Tunnels() ([]Tunnel, error) {
	names, err := conf.ListConfigNames()
	if err != nil {
		return nil, err
	}
	tunnels := make([]Tunnel, len(names))
	for i, name := range names {
		tunnels[i].Name = name
	}
	return tunnels, nil
	// TODO: account for running ones that aren't in the configuration store somehow
}
// Quit unregisters this service from notifications, optionally uninstalls
// every configured tunnel, and then sends one value on quitManagersChan.
// Only the first privileged call does the work; later calls return
// alreadyQuit == true. Clients without an elevated token get
// ERROR_ACCESS_DENIED.
//
// NOTE(review): if listing the configurations fails, haveQuit is already set
// and nothing is sent on quitManagersChan, so later calls report alreadyQuit
// without a quit ever having been signalled from here — confirm this is
// intended.
func (s *ManagerService) Quit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) {
	if s.elevatedToken == 0 {
		return false, windows.ERROR_ACCESS_DENIED
	}
	if firstCaller := atomic.CompareAndSwapUint32(&haveQuit, 0, 1); !firstCaller {
		return true, nil
	}

	// Work around potential race condition of delivering messages to the wrong process by removing from notifications.
	managerServicesLock.Lock()
	s.eventLock.Lock()
	s.events = nil
	s.eventLock.Unlock()
	delete(managerServices, s)
	managerServicesLock.Unlock()

	if stopTunnelsOnQuit {
		tunnelNames, listErr := conf.ListConfigNames()
		if listErr != nil {
			return false, listErr
		}
		// Uninstall results are deliberately not inspected.
		for i := range tunnelNames {
			UninstallTunnel(tunnelNames[i])
		}
	}

	quitManagersChan <- struct{}{}
	return false, nil
}
// UpdateState returns the current value of the package-level updateState.
func (s *ManagerService) UpdateState() UpdateState {
	current := updateState
	return current
}
// Update starts updater.DownloadVerifyAndExecute with the client's elevated
// token and forwards each progress value to IPCServerNotifyUpdateProgress
// from a background goroutine. Forwarding stops after a value that is marked
// Complete or carries an Error. Clients without an elevated token are
// ignored.
func (s *ManagerService) Update() {
	if s.elevatedToken == 0 {
		return
	}
	progress := updater.DownloadVerifyAndExecute(uintptr(s.elevatedToken))
	forward := func() {
		for finished := false; !finished; {
			dp := <-progress
			IPCServerNotifyUpdateProgress(dp)
			finished = dp.Complete || dp.Error != nil
		}
	}
	go forward()
}
// ServeConn runs the request loop for one client. Each request is a
// gob-encoded MethodType followed by that method's arguments. The matching
// method is called and its results are gob-encoded to writer in order, with
// errors sent as strings via errToString. The loop ends, without reporting
// anything, on the first decode or encode failure or on an unrecognised
// method type.
func (s *ManagerService) ServeConn(reader io.Reader, writer io.Writer) {
	decoder := gob.NewDecoder(reader)
	encoder := gob.NewEncoder(writer)

	// readName decodes the single tunnel-name argument that most methods take.
	readName := func() (string, bool) {
		var tunnelName string
		if decoder.Decode(&tunnelName) != nil {
			return "", false
		}
		return tunnelName, true
	}
	// reply encodes the values in order and reports false at the first failure.
	reply := func(values ...any) bool {
		for _, value := range values {
			if encoder.Encode(value) != nil {
				return false
			}
		}
		return true
	}

	for {
		var methodType MethodType
		if decoder.Decode(&methodType) != nil {
			return
		}

		switch methodType {
		case StoredConfigMethodType:
			tunnelName, ok := readName()
			if !ok {
				return
			}
			config, retErr := s.StoredConfig(tunnelName)
			if config == nil {
				// An empty config keeps the reply shape fixed on failure.
				config = &conf.Config{}
			}
			if !reply(*config, errToString(retErr)) {
				return
			}

		case RuntimeConfigMethodType:
			tunnelName, ok := readName()
			if !ok {
				return
			}
			config, retErr := s.RuntimeConfig(tunnelName)
			if config == nil {
				config = &conf.Config{}
			}
			if !reply(*config, errToString(retErr)) {
				return
			}

		case StartMethodType:
			tunnelName, ok := readName()
			if !ok {
				return
			}
			if !reply(errToString(s.Start(tunnelName))) {
				return
			}

		case StopMethodType:
			tunnelName, ok := readName()
			if !ok {
				return
			}
			if !reply(errToString(s.Stop(tunnelName))) {
				return
			}

		case WaitForStopMethodType:
			tunnelName, ok := readName()
			if !ok {
				return
			}
			if !reply(errToString(s.WaitForStop(tunnelName))) {
				return
			}

		case DeleteMethodType:
			tunnelName, ok := readName()
			if !ok {
				return
			}
			if !reply(errToString(s.Delete(tunnelName))) {
				return
			}

		case StateMethodType:
			tunnelName, ok := readName()
			if !ok {
				return
			}
			state, retErr := s.State(tunnelName)
			if !reply(state, errToString(retErr)) {
				return
			}

		case GlobalStateMethodType:
			if !reply(s.GlobalState()) {
				return
			}

		case CreateMethodType:
			var config conf.Config
			if decoder.Decode(&config) != nil {
				return
			}
			tunnel, retErr := s.Create(&config)
			if tunnel == nil {
				tunnel = &Tunnel{}
			}
			if !reply(tunnel, errToString(retErr)) {
				return
			}

		case TunnelsMethodType:
			tunnels, retErr := s.Tunnels()
			if !reply(tunnels, errToString(retErr)) {
				return
			}

		case QuitMethodType:
			var stopTunnelsOnQuit bool
			if decoder.Decode(&stopTunnelsOnQuit) != nil {
				return
			}
			alreadyQuit, retErr := s.Quit(stopTunnelsOnQuit)
			if !reply(alreadyQuit, errToString(retErr)) {
				return
			}

		case UpdateStateMethodType:
			if !reply(s.UpdateState()) {
				return
			}

		case UpdateMethodType:
			// Update sends nothing back on this connection.
			s.Update()

		default:
			return
		}
	}
}
// IPCServerListen creates a ManagerService for the given events file and
// token and returns immediately. A background goroutine registers the service
// in managerServices, serves requests from reader to writer until ServeConn
// returns, and then clears the events file and unregisters the service.
func IPCServerListen(reader, writer, events *os.File, elevatedToken windows.Token) {
	service := &ManagerService{events: events, elevatedToken: elevatedToken}

	register := func() {
		managerServicesLock.Lock()
		managerServices[service] = true
		managerServicesLock.Unlock()
	}
	unregister := func() {
		managerServicesLock.Lock()
		service.eventLock.Lock()
		service.events = nil
		service.eventLock.Unlock()
		delete(managerServices, service)
		managerServicesLock.Unlock()
	}

	go func() {
		register()
		service.ServeConn(reader, writer)
		unregister()
	}()
}
// notifyAll gob-encodes notificationType followed by each of ifaces into one
// buffer and writes that buffer to the events file of every registered
// ManagerService. When adminOnly is true, services whose elevatedToken is
// zero are skipped. An encoding failure aborts the whole notification
// silently. Write errors are ignored, so delivery is best-effort.
func notifyAll(notificationType NotificationType, adminOnly bool, ifaces ...any) {
	// managerServices is mutated by IPCServerListen and Quit on other
	// goroutines, so even the emptiness check must hold the lock.
	managerServicesLock.RLock()
	haveListeners := len(managerServices) != 0
	managerServicesLock.RUnlock()
	if !haveListeners {
		return
	}

	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
	if err := encoder.Encode(notificationType); err != nil {
		return
	}
	for _, iface := range ifaces {
		if err := encoder.Encode(iface); err != nil {
			return
		}
	}
	// The encoded bytes are only read from here on, so all writers share them.
	payload := buf.Bytes()

	managerServicesLock.RLock()
	defer managerServicesLock.RUnlock()
	for m := range managerServices {
		if adminOnly && m.elevatedToken == 0 {
			continue
		}
		// Each write runs in its own goroutine with a one-second deadline, so
		// the loop does not wait on any client.
		go func(m *ManagerService) {
			m.eventLock.Lock()
			defer m.eventLock.Unlock()
			if m.events != nil {
				m.events.SetWriteDeadline(time.Now().Add(time.Second))
				m.events.Write(payload)
			}
		}(m)
	}
}
// errToString returns err.Error(), or the empty string when err is nil.
func errToString(err error) string {
	if err != nil {
		return err.Error()
	}
	return ""
}
// IPCServerNotifyTunnelChange sends a TunnelChangeNotificationType
// notification to all clients. The payload is the tunnel name, its state, the
// value of trackedTunnelsGlobalState, and err rendered as a string.
func IPCServerNotifyTunnelChange(name string, state TunnelState, err error) {
	globalState := trackedTunnelsGlobalState()
	errText := errToString(err)
	notifyAll(TunnelChangeNotificationType, false, name, state, globalState, errText)
}
// IPCServerNotifyTunnelsChange sends a TunnelsChangeNotificationType
// notification with no payload to all clients.
func IPCServerNotifyTunnelsChange() {
	const adminOnly = false
	notifyAll(TunnelsChangeNotificationType, adminOnly)
}
// IPCServerNotifyUpdateFound sends an UpdateFoundNotificationType
// notification carrying state to all clients.
func IPCServerNotifyUpdateFound(state UpdateState) {
	const adminOnly = false
	notifyAll(UpdateFoundNotificationType, adminOnly, state)
}
// IPCServerNotifyUpdateProgress sends an UpdateProgressNotificationType
// notification to elevated clients only. The payload is dp's Activity,
// BytesDownloaded and BytesTotal, its Error rendered as a string, and its
// Complete flag, in that order.
func IPCServerNotifyUpdateProgress(dp updater.DownloadProgress) {
	const adminOnly = true
	errText := errToString(dp.Error)
	notifyAll(UpdateProgressNotificationType, adminOnly, dp.Activity, dp.BytesDownloaded, dp.BytesTotal, errText, dp.Complete)
}
// IPCServerNotifyManagerStopping sends a ManagerStoppingNotificationType
// notification to all clients and then sleeps for 200 milliseconds.
// NOTE(review): the pause presumably gives the asynchronous writes started by
// notifyAll time to go out before the caller proceeds — confirm.
func IPCServerNotifyManagerStopping() {
	const settleDelay = 200 * time.Millisecond
	notifyAll(ManagerStoppingNotificationType, false)
	time.Sleep(settleDelay)
}