From ba4dd48d5a99b348fc0b4857152ab3ed5e4e76cf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Dec 2025 15:45:39 +0200 Subject: [PATCH] gotd: ensure user is member of channels before starting getDifference loop --- pkg/connector/client.go | 1 + pkg/connector/handletelegram.go | 4 ++ pkg/connector/store/scoped_store.go | 23 ++++--- pkg/gotd/telegram/updates/config.go | 4 +- .../telegram/updates/internal/e2e/server.go | 6 ++ pkg/gotd/telegram/updates/manager.go | 60 +++++++++++++++---- pkg/gotd/telegram/updates/state.go | 11 ++-- pkg/gotd/telegram/updates/state_channel.go | 10 ++-- 8 files changed, 88 insertions(+), 31 deletions(-) diff --git a/pkg/connector/client.go b/pkg/connector/client.go index c489e33d..60e574cf 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -226,6 +226,7 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge dispatcher.OnPhoneCall(client.onPhoneCall) client.updatesManager = updates.New(updates.Config{ + OnNotChannelMember: client.onNotChannelMember, OnChannelTooLong: func(channelID int64) error { // TODO resync topics? res := tc.Bridge.QueueRemoteEvent(login, &simplevent.ChatResync{ diff --git a/pkg/connector/handletelegram.go b/pkg/connector/handletelegram.go index 4ecefbff..60145b3d 100644 --- a/pkg/connector/handletelegram.go +++ b/pkg/connector/handletelegram.go @@ -107,6 +107,10 @@ func (t *TelegramClient) selfLeaveChat(ctx context.Context, portalKey networkid. return nil } +func (t *TelegramClient) onNotChannelMember(ctx context.Context, channelID int64) error { + return t.selfLeaveChat(ctx, t.makePortalKeyFromID(ids.PeerTypeChannel, channelID, 0), fmt.Errorf("startup channel member check failed")) +} + func (t *TelegramClient) onUpdateChannel(ctx context.Context, e tg.Entities, update *tg.UpdateChannel) error { log := zerolog.Ctx(ctx).With(). Str("handler", "on_update_channel"). diff --git a/pkg/connector/store/scoped_store.go b/pkg/connector/store/scoped_store.go index 9ddc8741..1472f511 100644 --- a/pkg/connector/store/scoped_store.go +++ b/pkg/connector/store/scoped_store.go @@ -74,19 +74,26 @@ const ( var _ updates.StateStorage = (*ScopedStore)(nil) +type channelIDPtsTuple struct { + ChannelID int64 + Pts int +} + +var ciptScanner = dbutil.ConvertRowFn[channelIDPtsTuple](func(row dbutil.Scannable) (cipt channelIDPtsTuple, err error) { + err = row.Scan(&cipt.ChannelID, &cipt.Pts) + return +}) + func (s *ScopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error { s.assertUserIDMatches(userID) - rows, err := s.db.Query(ctx, allChannelsQuery, userID) + items, err := ciptScanner.NewRowIter(s.db.Query(ctx, allChannelsQuery, userID)).AsList() if err != nil { return err } - var channelID int64 - var pts int - for rows.Next() { - if err = rows.Scan(&channelID, &pts); err != nil { - return err - } else if err = f(ctx, channelID, pts); err != nil { - return err + for _, item := range items { + err = f(ctx, item.ChannelID, item.Pts) + if err != nil { + return fmt.Errorf("iteration error for channel %d: %w", item.ChannelID, err) } } return nil diff --git a/pkg/gotd/telegram/updates/config.go b/pkg/gotd/telegram/updates/config.go index 7b0e0e2a..8a5d6906 100644 --- a/pkg/gotd/telegram/updates/config.go +++ b/pkg/gotd/telegram/updates/config.go @@ -16,6 +16,7 @@ type API interface { UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error) UpdatesGetDifference(ctx context.Context, request *tg.UpdatesGetDifferenceRequest) (tg.UpdatesDifferenceClass, error) UpdatesGetChannelDifference(ctx context.Context, request *tg.UpdatesGetChannelDifferenceRequest) (tg.UpdatesChannelDifferenceClass, error) + ChannelsGetParticipant(ctx context.Context, request *tg.ChannelsGetParticipantRequest) (*tg.ChannelsChannelParticipant, error) } // Config of the manager. @@ -24,7 +25,8 @@ type Config struct { Handler telegram.UpdateHandler // Callback called if manager cannot // recover channel gap (optional). - OnChannelTooLong func(channelID int64) error + OnChannelTooLong func(channelID int64) error + OnNotChannelMember func(ctx context.Context, channelID int64) error // State storage. // In-mem used if not provided. Storage StateStorage diff --git a/pkg/gotd/telegram/updates/internal/e2e/server.go b/pkg/gotd/telegram/updates/internal/e2e/server.go index 00a57081..2c08a2cc 100644 --- a/pkg/gotd/telegram/updates/internal/e2e/server.go +++ b/pkg/gotd/telegram/updates/internal/e2e/server.go @@ -34,6 +34,12 @@ func newServer() *server { } } +func (s *server) ChannelsGetParticipant(ctx context.Context, request *tg.ChannelsGetParticipantRequest) (*tg.ChannelsChannelParticipant, error) { + return &tg.ChannelsChannelParticipant{ + Participant: &tg.ChannelParticipantSelf{}, + }, nil +} + // UpdatesGetState returns current remote state. func (s *server) UpdatesGetState(ctx context.Context) (*tg.UpdatesState, error) { s.mux.Lock() diff --git a/pkg/gotd/telegram/updates/manager.go b/pkg/gotd/telegram/updates/manager.go index 2003db49..813dc249 100644 --- a/pkg/gotd/telegram/updates/manager.go +++ b/pkg/gotd/telegram/updates/manager.go @@ -2,6 +2,7 @@ package updates import ( "context" + "fmt" "sync" "github.com/go-faster/errors" @@ -11,6 +12,7 @@ import ( "go.mau.fi/mautrix-telegram/pkg/gotd/telegram" "go.mau.fi/mautrix-telegram/pkg/gotd/tg" + "go.mau.fi/mautrix-telegram/pkg/gotd/tgerr" ) var _ telegram.UpdateHandler = (*Manager)(nil) @@ -74,6 +76,47 @@ type AuthOptions struct { OnStart func(ctx context.Context) } +type PtsAccessHashTuple struct { + Pts int + AccessHash int64 +} + +func (m *Manager) checkParticipant(ctx context.Context, api API, userID, channelID, hash int64) error { + lg := m.lg.With(zap.Int64("channel_id", channelID)) + lg.Info("Ensuring user is still in channel") + pcp, err := api.ChannelsGetParticipant(ctx, &tg.ChannelsGetParticipantRequest{ + Channel: &tg.InputChannel{ + ChannelID: channelID, + AccessHash: hash, + }, + Participant: &tg.InputPeerSelf{}, + }) + if err != nil { + if tgerr.Is(err, tg.ErrChannelInvalid, tg.ErrChannelPrivate, tg.ErrUserNotParticipant) { + lg.Warn("Removing update state for channel after error", zap.Error(err)) + } else { + lg.Error("channels.getParticipant failed", zap.Error(err)) + // TODO fatal error? + return nil + } + } else { + switch pcp.Participant.(type) { + case *tg.ChannelParticipantLeft, *tg.ChannelParticipantBanned: + lg.Warn("Removing update state for channel as user is left or banned") + default: + lg.Debug("Membership confirmed", zap.Any("participant", pcp.Participant)) + return nil + } + } + + if err := m.cfg.Storage.SetChannelPts(ctx, userID, channelID, -1); err != nil { + return fmt.Errorf("failed to clear pts: %w", err) + } else if err = m.cfg.OnNotChannelMember(ctx, channelID); err != nil { + return fmt.Errorf("OnNotChannelMember callback failed: %w", err) + } + return nil +} + // Run notifies manager about user authentication on the telegram server. // // If forget is true, local internalState (if exist) will be overwritten @@ -101,10 +144,7 @@ func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOption if err != nil { return errors.Wrap(err, "load internalState") } - channels := make(map[int64]struct { - Pts int - AccessHash int64 - }) + channels := make(map[int64]PtsAccessHashTuple) if err := m.cfg.Storage.ForEachChannels(ctx, userID, func(ctx context.Context, channelID int64, pts int) error { if pts == -1 { return nil @@ -112,16 +152,12 @@ func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOption hash, found, err := m.cfg.AccessHasher.GetChannelAccessHash(ctx, userID, channelID) if err != nil { return errors.Wrap(err, "get channel access hash") - } - - if !found { + } else if !found { return nil + } else if err = m.checkParticipant(ctx, api, userID, channelID, hash); err != nil { + return fmt.Errorf("failed to check if user is participant of channel %d: %w", channelID, err) } - - channels[channelID] = struct { - Pts int - AccessHash int64 - }{Pts: pts, AccessHash: hash} + channels[channelID] = PtsAccessHashTuple{Pts: pts, AccessHash: hash} return nil }); err != nil { return errors.Wrap(err, "iterate channels") diff --git a/pkg/gotd/telegram/updates/state.go b/pkg/gotd/telegram/updates/state.go index f9268762..077ab035 100644 --- a/pkg/gotd/telegram/updates/state.go +++ b/pkg/gotd/telegram/updates/state.go @@ -62,11 +62,8 @@ type internalState struct { } type stateConfig struct { - State State - Channels map[int64]struct { - Pts int - AccessHash int64 - } + State State + Channels map[int64]PtsAccessHashTuple RawClient API Logger *zap.Logger Tracer trace.Tracer @@ -422,6 +419,10 @@ func (s *internalState) createAndRunChannelState(ctx context.Context, channelID, s.channelsLock.Unlock() s.log.Info("Removed channel state due to error", zap.Int64("channel_id", channelID), zap.Error(err)) return nil + } else if ctx.Err() == nil { + s.log.Error("Channel state stopped with unexpected error, new messages may stop arriving", + zap.Int64("channel_id", channelID), zap.Error(err)) + return nil } return err }) diff --git a/pkg/gotd/telegram/updates/state_channel.go b/pkg/gotd/telegram/updates/state_channel.go index eb025aa5..989b58ca 100644 --- a/pkg/gotd/telegram/updates/state_channel.go +++ b/pkg/gotd/telegram/updates/state_channel.go @@ -205,12 +205,12 @@ func (s *channelState) applyPts(ctx context.Context, state int, updates []update Users: ents.Users, Chats: ents.Chats, }); err != nil { - s.log.Error("Handle update error", zap.Error(err)) + s.log.Error("Handle update error (applyPts)", zap.Error(err)) return nil } if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, state); err != nil { - s.log.Error("SetChannelPts error", zap.Error(err)) + s.log.Error("SetChannelPts error (applyPts)", zap.Error(err)) } return nil @@ -297,7 +297,7 @@ func (s *channelState) getDifference(ctx context.Context) error { } if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil { - s.log.Warn("SetChannelPts error", zap.Error(err)) + s.log.Warn("SetChannelPts error (getDifference)", zap.Error(err)) } s.pts.SetState(diff.Pts, "updates.channelDifference") @@ -313,7 +313,7 @@ func (s *channelState) getDifference(ctx context.Context) error { case *tg.UpdatesChannelDifferenceEmpty: if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, diff.Pts); err != nil { - s.log.Warn("SetChannelPts error", zap.Error(err)) + s.log.Warn("SetChannelPts error (getDifference empty)", zap.Error(err)) } s.pts.SetState(diff.Pts, "updates.channelDifferenceEmpty") @@ -333,7 +333,7 @@ func (s *channelState) getDifference(ctx context.Context) error { s.log.Warn("UpdatesChannelDifferenceTooLong invalid Dialog", zap.Error(err)) } else { if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, remotePts); err != nil { - s.log.Warn("SetChannelPts error", zap.Error(err)) + s.log.Warn("SetChannelPts error (getDifference too long)", zap.Error(err)) } s.pts.SetState(remotePts, "updates.channelDifferenceTooLong dialog new pts")