Add option to sync portals in backfill queue

This commit is contained in:
Tulir Asokan
2022-10-14 13:55:12 +03:00
parent af2f20f7b2
commit 0bbf64d240
10 changed files with 315 additions and 116 deletions

View File

@@ -620,7 +620,7 @@ class AbstractUser(ABC):
self.log.info(
"Creating Matrix room with data fetched by Telethon due to UpdateChannel"
)
await portal.create_matrix_room(self, chan)
await portal.create_matrix_room(self, chan, invites=[self.mxid])
async def update_message(self, original_update: UpdateMessage) -> None:
update, sender, portal = await self.get_message_details(original_update)

View File

@@ -113,6 +113,7 @@ class Config(BaseBridgeConfig):
else:
copy("bridge.sync_update_limit")
copy("bridge.sync_create_limit")
copy("bridge.sync_deferred_create_all")
copy("bridge.sync_direct_chats")
copy("bridge.max_telegram_delete")
copy("bridge.sync_matrix_state")

View File

@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Database
from .backfill_queue import Backfill
from .backfill_queue import Backfill, BackfillType
from .bot_chat import BotChat
from .disappearing_message import DisappearingMessage
from .message import Message

View File

@@ -15,8 +15,10 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar
from datetime import datetime, timedelta
from enum import Enum
import json
from asyncpg import Record
from attr import dataclass
@@ -29,6 +31,11 @@ from ..types import TelegramID
fake_db = Database.create("") if TYPE_CHECKING else None
class BackfillType(Enum):
    """The kind of work a backfill queue entry represents."""

    # Backfill historical messages into an existing portal.
    HISTORICAL = "historical"
    # Deferred dialog sync: create the portal and run post-sync steps later.
    SYNC_DIALOG = "sync_dialog"
@dataclass
class Backfill:
db: ClassVar[Database] = fake_db
@@ -36,9 +43,11 @@ class Backfill:
queue_id: int | None
user_mxid: UserID
priority: int
type: BackfillType
portal_tgid: TelegramID
portal_tg_receiver: TelegramID
anchor_msg_id: TelegramID | None
extra_data: dict[str, Any]
messages_per_batch: int
post_batch_delay: int
max_batches: int
@@ -50,10 +59,12 @@ class Backfill:
def new(
user_mxid: UserID,
priority: int,
type: BackfillType,
portal_tgid: TelegramID,
portal_tg_receiver: TelegramID,
messages_per_batch: int,
anchor_msg_id: TelegramID | None = None,
extra_data: dict[str, Any] | None = None,
post_batch_delay: int = 0,
max_batches: int = -1,
) -> "Backfill":
@@ -61,9 +72,11 @@ class Backfill:
queue_id=None,
user_mxid=user_mxid,
priority=priority,
type=type,
portal_tgid=portal_tgid,
portal_tg_receiver=portal_tg_receiver,
anchor_msg_id=anchor_msg_id,
extra_data=extra_data or {},
messages_per_batch=messages_per_batch,
post_batch_delay=post_batch_delay,
max_batches=max_batches,
@@ -76,14 +89,19 @@ class Backfill:
def _from_row(cls, row: Record | None) -> Backfill | None:
    """Convert a database row into a Backfill instance.

    Returns None when the row is None (e.g. an empty fetchrow result).
    The ``type`` column is stored as the enum's string value and
    ``extra_data`` as a JSON string (or NULL), so both are decoded here.
    """
    if row is None:
        return None
    data = {**row}
    type = BackfillType(data.pop("type"))
    extra_data = json.loads(data.pop("extra_data", None) or "{}")
    return cls(**data, type=type, extra_data=extra_data)
columns = [
"user_mxid",
"priority",
"type",
"portal_tgid",
"portal_tg_receiver",
"anchor_msg_id",
"extra_data",
"messages_per_batch",
"post_batch_delay",
"max_batches",
@@ -118,22 +136,37 @@ class Backfill:
)
@classmethod
async def get(
async def delete_existing(
cls,
user_mxid: UserID,
portal_tgid: int,
portal_tg_receiver: int,
type: BackfillType,
) -> Backfill | None:
q = f"""
SELECT queue_id, {cls.columns_str}
FROM backfill_queue
WHERE user_mxid=$1
AND portal_tgid=$2
AND portal_tg_receiver=$3
ORDER BY priority, queue_id
LIMIT 1
WITH deleted_entries AS (
DELETE FROM backfill_queue
WHERE user_mxid=$1
AND portal_tgid=$2
AND portal_tg_receiver=$3
AND type=$4
AND dispatch_time IS NULL
AND completed_at IS NULL
RETURNING 1
)
WITH dispatched_entries AS (
SELECT 1 FROM backfill_queue
WHERE user_mxid=$1
AND portal_tgid=$2
AND portal_tg_receiver=$3
AND type=$4
AND dispatch_time IS NOT NULL
AND completed_at IS NULL
)
"""
return cls._from_row(await cls.db.fetchrow(q, user_mxid, portal_tgid, portal_tg_receiver))
return cls._from_row(
await cls.db.fetchrow(q, user_mxid, portal_tgid, portal_tg_receiver, type.value)
)
@classmethod
async def delete_all(cls, user_mxid: UserID) -> None:
@@ -144,27 +177,47 @@ class Backfill:
q = "DELETE FROM backfill_queue WHERE portal_tgid=$1 AND portal_tg_receiver=$2"
await cls.db.execute(q, tgid, tg_receiver)
async def insert(self) -> list[Backfill]:
    """Insert this backfill request into the queue.

    Any not-yet-dispatched, not-yet-completed entries for the same user,
    portal and type are deleted first so the queue never holds duplicate
    pending requests.

    Returns:
        The queue entries that were deleted to make room for this one.
    """
    delete_q = f"""
        DELETE FROM backfill_queue
        WHERE user_mxid=$1
          AND portal_tgid=$2
          AND portal_tg_receiver=$3
          AND type=$4
          AND dispatch_time IS NULL
          AND completed_at IS NULL
        RETURNING {self.columns_str}
    """
    insert_q = f"""
        INSERT INTO backfill_queue ({self.columns_str})
        VALUES ({','.join(f'${i+1}' for i in range(len(self.columns)))})
        RETURNING queue_id
    """
    # Run the delete and insert atomically. Both queries must go through the
    # acquired connection (not self.db, which is the pool) so that they
    # actually execute inside the transaction opened here.
    async with self.db.acquire() as conn, conn.transaction():
        deleted_rows = await conn.fetch(
            delete_q,
            self.user_mxid,
            self.portal_tgid,
            self.portal_tg_receiver,
            self.type.value,
        )
        self.queue_id = await conn.fetchval(
            insert_q,
            self.user_mxid,
            self.priority,
            self.type.value,
            self.portal_tgid,
            self.portal_tg_receiver,
            self.anchor_msg_id,
            json.dumps(self.extra_data) if self.extra_data else None,
            self.messages_per_batch,
            self.post_batch_delay,
            self.max_batches,
            self.dispatch_time,
            self.completed_at,
            self.cooldown_timeout,
        )
    return [self._from_row(row) for row in deleted_rows]
async def mark_dispatched(self) -> None:
q = "UPDATE backfill_queue SET dispatch_time=$1 WHERE queue_id=$2"

View File

@@ -18,4 +18,5 @@ from . import (
v13_multiple_reactions,
v14_puppet_custom_mxid_index,
v15_backfill_anchor_id,
v16_backfill_type,
)

View File

@@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Connection, Scheme
latest_version = 15
latest_version = 16
async def create_latest_tables(conn: Connection, scheme: Scheme) -> int:
@@ -219,9 +219,11 @@ async def create_latest_tables(conn: Connection, scheme: Scheme) -> int:
queue_id INTEGER PRIMARY KEY {gen},
user_mxid TEXT,
priority INTEGER NOT NULL,
type TEXT NOT NULL,
portal_tgid BIGINT,
portal_tg_receiver BIGINT,
anchor_msg_id BIGINT,
extra_data jsonb,
messages_per_batch INTEGER NOT NULL,
post_batch_delay INTEGER NOT NULL,
max_batches INTEGER NOT NULL,

View File

@@ -0,0 +1,28 @@
# mautrix-telegram - A Matrix-Telegram puppeting bridge
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.async_db import Connection, Scheme
from . import upgrade_table
@upgrade_table.register(description="Add type for backfill queue items")
async def upgrade_v16(conn: Connection, scheme: Scheme) -> None:
    """Add the type and extra_data columns to the backfill_queue table."""
    statements = [
        # Rows that existed before this migration can only be of the
        # historical type, so backfill the new column with that value.
        "ALTER TABLE backfill_queue ADD COLUMN type TEXT NOT NULL DEFAULT 'historical'",
        "ALTER TABLE backfill_queue ADD COLUMN extra_data jsonb",
    ]
    if scheme != Scheme.SQLITE:
        # New rows must set the type explicitly. SQLite can't drop a column
        # default, so the default is left in place there.
        statements.append("ALTER TABLE backfill_queue ALTER COLUMN type DROP DEFAULT")
    for stmt in statements:
        await conn.execute(stmt)

View File

@@ -169,7 +169,10 @@ bridge:
sync_update_limit: 0
# Number of most recently active dialogs to create portals for when syncing chats.
# Set to 0 to remove limit.
sync_create_limit: 30
sync_create_limit: 15
# Should portal creation for all synced chats be deferred to the backfill queue
# instead of happening immediately during sync?
# This is best used in combination with MSC2716 infinite backfill.
sync_deferred_create_all: false
# Whether or not to sync and create portals for direct chats at startup.
sync_direct_chats: false
# The maximum number of simultaneous Telegram deletions to handle.

View File

@@ -187,6 +187,7 @@ from . import (
from .config import Config
from .db import (
Backfill,
BackfillType,
DisappearingMessage,
Message as DBMessage,
Portal as DBPortal,
@@ -257,6 +258,7 @@ class Portal(DBPortal, BasePortal):
backfill_method_lock: asyncio.Lock
backfill_leave: set[IntentAPI] | None
backfill_msc2716: bool
backfill_enable: bool
alias: RoomAlias | None
@@ -439,6 +441,7 @@ class Portal(DBPortal, BasePortal):
cls.filter_list = cls.config["bridge.filter.list"]
cls.hs_domain = cls.config["homeserver.domain"]
cls.backfill_msc2716 = cls.config["bridge.backfill.msc2716"]
cls.backfill_enable = cls.config["bridge.backfill.enable"]
cls.alias_template = SimpleTemplate(
cls.config["bridge.alias_template"],
"groupname",
@@ -645,9 +648,10 @@ class Portal(DBPortal, BasePortal):
puppet: p.Puppet = None,
levels: PowerLevelStateEventContent = None,
users: list[User] = None,
client: MautrixTelegramClient | None = None,
) -> None:
try:
await self._update_matrix_room(user, entity, puppet, levels, users)
await self._update_matrix_room(user, entity, puppet, levels, users, client)
except Exception:
self.log.exception("Fatal error updating Matrix room")
@@ -658,12 +662,15 @@ class Portal(DBPortal, BasePortal):
puppet: p.Puppet = None,
levels: PowerLevelStateEventContent = None,
users: list[User] = None,
client: MautrixTelegramClient | None = None,
) -> None:
if not client:
client = user.client
if not self.is_direct:
await self.update_info(user, entity)
await self.update_info(user, entity, client=client)
if not users:
users = await self._get_users(user, entity)
await self._sync_telegram_users(user, users)
users = await self._get_users(client, entity)
await self._sync_telegram_users(user, users, client=client)
await self.update_power_levels(users, levels)
else:
if not puppet:
@@ -708,12 +715,13 @@ class Portal(DBPortal, BasePortal):
entity: TypeChat | User = None,
invites: InviteList = None,
update_if_exists: bool = True,
client: MautrixTelegramClient | None = None,
) -> RoomID | None:
if self.mxid:
if update_if_exists:
if not entity:
try:
entity = await self.get_entity(user)
entity = await self.get_entity(user, client)
except Exception:
self.log.exception(f"Failed to get entity through {user.tgid} for update")
return self.mxid
@@ -723,7 +731,7 @@ class Portal(DBPortal, BasePortal):
return self.mxid
async with self._room_create_lock:
try:
return await self._create_matrix_room(user, entity, invites)
return await self._create_matrix_room(user, entity, invites, client=client)
except Exception:
self.log.exception("Fatal error creating Matrix room")
@@ -774,17 +782,23 @@ class Portal(DBPortal, BasePortal):
self.log.warning("Failed to update bridge info", exc_info=True)
async def _create_matrix_room(
self, user: au.AbstractUser, entity: TypeChat | User, invites: InviteList
self,
user: au.AbstractUser,
entity: TypeChat | User,
invites: InviteList,
client: MautrixTelegramClient | None = None,
) -> RoomID | None:
if self.mxid:
return self.mxid
elif not self.allow_bridging:
return None
if not client:
client = user.client
invites = invites or []
if not entity:
entity = await self.get_entity(user)
entity = await self.get_entity(user, client)
self.log.trace("Fetched data: %s", entity)
participants_count = 2
@@ -794,17 +808,17 @@ class Portal(DBPortal, BasePortal):
participants_count = entity.participants_count
if participants_count is None and self.config["bridge.max_member_count"] > 0:
self.log.warning(f"Participant count not found in entity, fetching manually")
participants_count = (await user.client.get_participants(entity, limit=0)).total
participants_count = (await client.get_participants(entity, limit=0)).total
if participants_count and 0 < self.config["bridge.max_member_count"] < participants_count:
self.log.warning(f"Not bridging chat, too many participants (%d)", participants_count)
self._bridging_blocked_at_runtime = True
return None
self.log.debug("Creating room")
self.log.debug("Preparing to create room")
if self.is_direct:
puppet = await self.get_dm_puppet()
await puppet.update_info(user, entity)
await puppet.update_info(user, entity, client_override=client)
self._main_intent = puppet.intent_for(self)
if self.tgid == user.tgid:
self.title = "Telegram Saved Messages"
@@ -812,7 +826,7 @@ class Portal(DBPortal, BasePortal):
else:
puppet = None
self._main_intent = self.az.intent
await self.update_info(user, entity)
await self.update_info(user, entity, client=client)
preset = RoomCreatePreset.PRIVATE
if self.peer_type == "channel" and entity.username:
@@ -831,7 +845,7 @@ class Portal(DBPortal, BasePortal):
power_levels = putil.get_base_power_levels(self, entity=entity)
users = None
if not self.is_direct:
users = await self._get_users(user, entity)
users = await self._get_users(client, entity)
if self.has_bot:
extra_invites = self.config["bridge.relaybot.group_chat_invite"]
invites += extra_invites
@@ -840,7 +854,7 @@ class Portal(DBPortal, BasePortal):
await putil.participants_to_power_levels(self, users, power_levels)
elif self.bot and self.tg_receiver == self.bot.tgid:
assert puppet is not None
invites = self.config["bridge.relaybot.private_chat.invite"]
invites += self.config["bridge.relaybot.private_chat.invite"]
for invite in invites:
power_levels.users.setdefault(invite, 100)
self.title = puppet.displayname
@@ -865,10 +879,10 @@ class Portal(DBPortal, BasePortal):
autojoin_invites = self.bridge.homeserver_software.is_hungry
create_invites = set()
if autojoin_invites:
invites = []
create_invites |= set(invites)
invites = []
if not self.is_direct:
create_invites |= await self._sync_telegram_users(user, users)
create_invites |= await self._sync_telegram_users(user, users, client=client)
if self.config["bridge.encryption.default"] and self.matrix.e2ee:
self.encrypted = True
initial_state.append(
@@ -896,6 +910,11 @@ class Portal(DBPortal, BasePortal):
)
with self.backfill_lock:
self.log.debug(
f"Creating room with parameters invite={create_invites}, {autojoin_invites=}, "
f"{preset=}, {alias=!r}, name={self.title!r}, topic={self.about!r}, "
f"{creation_content=}, is_direct={self.is_direct}"
)
room_id = await self.main_intent.create_room(
alias_localpart=alias,
preset=preset,
@@ -912,7 +931,7 @@ class Portal(DBPortal, BasePortal):
self.name_set = bool(self.title)
self.avatar_set = bool(self.avatar_url)
if self.encrypted and self.matrix.e2ee and self.is_direct:
if not autojoin_invites and self.encrypted and self.matrix.e2ee and self.is_direct:
try:
await self.az.intent.ensure_joined(room_id)
except Exception:
@@ -928,7 +947,7 @@ class Portal(DBPortal, BasePortal):
if not autojoin_invites or not self.is_direct:
await self.invite_to_matrix(invites)
await self.update_matrix_room(
user, entity, puppet, levels=power_levels, users=users
user, entity, puppet, levels=power_levels, users=users, client=client
)
else:
# When using autojoining, all metadata is already set, so just update state caches
@@ -943,9 +962,9 @@ class Portal(DBPortal, BasePortal):
)
await self.save()
if isinstance(user, u.User) or not self.backfill_msc2716:
if self.backfill_enable and (isinstance(user, u.User) or not self.backfill_msc2716):
try:
await self.forward_backfill(user, initial=True)
await self.forward_backfill(user, initial=True, client=client)
except Exception:
self.log.exception("Error in initial backfill")
if self.backfill_msc2716:
@@ -955,7 +974,7 @@ class Portal(DBPortal, BasePortal):
async def _get_users(
self,
user: au.AbstractUser,
client: MautrixTelegramClient,
entity: TypeInputPeer | InputUser | TypeChat | TypeUser | InputChannel,
) -> list[TypeUser]:
if self.peer_type == "channel" and not self.megagroup and not self.sync_channel_members:
@@ -963,7 +982,7 @@ class Portal(DBPortal, BasePortal):
limit = self.max_initial_member_sync
if limit == 0:
return []
return await putil.get_users(user.client, self.tgid, entity, limit, self.peer_type)
return await putil.get_users(client, self.tgid, entity, limit, self.peer_type)
async def update_power_levels(
self,
@@ -985,7 +1004,10 @@ class Portal(DBPortal, BasePortal):
await user.register_portal(self)
async def _sync_telegram_users(
self, source: au.AbstractUser, users: list[User]
self,
source: au.AbstractUser,
users: list[User],
client: MautrixTelegramClient | None = None,
) -> set[UserID] | None:
allowed_tgids = set()
join_mxids = set()
@@ -996,7 +1018,7 @@ class Portal(DBPortal, BasePortal):
await self._add_bot_chat(entity)
allowed_tgids.add(entity.id)
await puppet.update_info(source, entity)
await puppet.update_info(source, entity, client_override=client)
if skip_deleted and entity.deleted:
continue
@@ -1122,7 +1144,12 @@ class Portal(DBPortal, BasePortal):
except MForbidden as e:
self.log.warning(f"Failed to kick {user.mxid}: {e}")
async def update_info(self, user: au.AbstractUser, entity: TypeChat = None) -> None:
async def update_info(
self,
user: au.AbstractUser,
entity: TypeChat = None,
client: MautrixTelegramClient | None = None,
) -> None:
if self.peer_type == "user":
self.log.warning("Called update_info() for direct chat portal")
return
@@ -1131,7 +1158,7 @@ class Portal(DBPortal, BasePortal):
self.log.debug("Updating info")
try:
if not entity:
entity = await self.get_entity(user)
entity = await self.get_entity(user, client)
self.log.trace("Fetched data: %s", entity)
if self.peer_type == "channel":
@@ -1145,7 +1172,7 @@ class Portal(DBPortal, BasePortal):
changed = await self._update_title(entity.title) or changed
if isinstance(entity.photo, ChatPhoto):
changed = await self._update_avatar(user, entity.photo) or changed
changed = await self._update_avatar(user, entity.photo, client=client) or changed
except Exception:
self.log.exception(f"Failed to update info from source {user.tgid}")
@@ -1254,6 +1281,7 @@ class Portal(DBPortal, BasePortal):
photo: TypeChatPhoto | TypeUserProfilePhoto,
sender: p.Puppet | None = None,
save: bool = False,
client: MautrixTelegramClient | None = None,
) -> bool:
if isinstance(photo, (ChatPhoto, UserProfilePhoto)):
loc = InputPeerPhotoFileLocation(
@@ -1280,7 +1308,7 @@ class Portal(DBPortal, BasePortal):
self.avatar_url = None
elif self.photo_id != photo_id or not self.avatar_url:
file = await util.transfer_file_to_matrix(
user.client,
client or user.client,
self.main_intent,
loc,
async_upload=self.config["homeserver.async_media"],
@@ -2649,21 +2677,28 @@ class Portal(DBPortal, BasePortal):
max_batches: int | None = None,
messages_per_batch: int | None = None,
anchor_msg_id: int | None = None,
extra_data: dict[str, Any] | None = None,
type: BackfillType = BackfillType.HISTORICAL,
) -> None:
# TODO check that there are no queued backfills
# if not await Backfill.get(source.mxid, self.tgid, self.tg_receiver):
await Backfill.new(
new_backfill = Backfill.new(
user_mxid=source.mxid,
priority=priority,
type=type,
portal_tgid=self.tgid,
portal_tg_receiver=self.tg_receiver,
anchor_msg_id=anchor_msg_id,
extra_data=extra_data,
messages_per_batch=(
messages_per_batch or self.config["bridge.backfill.incremental.messages_per_batch"]
),
post_batch_delay=self.config["bridge.backfill.incremental.post_batch_delay"],
max_batches=max_batches or self._default_max_batches,
).insert()
)
deleted_entries = await new_backfill.insert()
if deleted_entries:
self.log.debug(
"Deleted backfill queue entries while inserting new item: %s", deleted_entries
)
source.wakeup_backfill_task.set()
async def forward_backfill(
@@ -2672,14 +2707,17 @@ class Portal(DBPortal, BasePortal):
initial: bool,
last_tgid: int | None = None,
override_limit: int | None = None,
client: MautrixTelegramClient | None = None,
) -> str:
if not client:
client = source.client
type = "initial" if initial else "sync"
limit = override_limit or self.config[f"bridge.backfill.forward.{type}_limit"]
if limit == 0:
return "Limit is zero, not backfilling"
with self.backfill_lock:
output = await self.backfill(
source, source.client, forward=True, forward_limit=limit, last_tgid=last_tgid
source, client, forward=True, forward_limit=limit, last_tgid=last_tgid
)
self.log.debug(f"Forward backfill complete, status: {output}")
return output
@@ -2693,6 +2731,8 @@ class Portal(DBPortal, BasePortal):
forward_limit: int | None = None,
last_tgid: int | None = None,
) -> str:
if not self.backfill_enable:
return "Backfilling is disabled in the bridge config"
async with self.backfill_method_lock:
return await self._locked_backfill(
source, client, req, forward, forward_limit, last_tgid
@@ -2778,7 +2818,7 @@ class Portal(DBPortal, BasePortal):
self.log.debug(f"Enqueuing more backfill through {source.mxid}")
await self.enqueue_backfill(
source,
priority=100,
priority=max(100, req.priority + 1),
messages_per_batch=req.messages_per_batch,
max_batches=-1 if req.max_batches < 0 else (req.max_batches - 1),
anchor_msg_id=lowest_id,
@@ -3515,9 +3555,13 @@ class Portal(DBPortal, BasePortal):
) -> Awaitable[TypeInputPeer | TypeInputChannel]:
return user.client.get_input_entity(self.peer)
async def get_entity(self, user: au.AbstractUser) -> TypeChat:
async def get_entity(
self, user: au.AbstractUser, client: MautrixTelegramClient | None = None
) -> TypeChat:
if not client:
client = user.client
try:
return await user.client.get_entity(self.peer)
return await client.get_entity(self.peer)
except ValueError:
if user.is_bot:
self.log.warning(f"Could not find entity with bot {user.tgid}. Failing...")
@@ -3525,7 +3569,7 @@ class Portal(DBPortal, BasePortal):
self.log.warning(
f"Could not find entity with user {user.tgid}. falling back to get_dialogs."
)
async for dialog in user.client.iter_dialogs():
async for dialog in client.iter_dialogs():
if dialog.entity.id == self.tgid:
return dialog.entity
raise

View File

@@ -16,7 +16,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, NamedTuple, cast
from datetime import datetime, timedelta, timezone
from datetime import datetime
import asyncio
import time
@@ -39,6 +39,7 @@ from telethon.tl.types import (
InputUserSelf,
NotifyPeer,
PeerUser,
TypeChat,
TypeUpdate,
UpdateFolderPeers,
UpdateNewChannelMessage,
@@ -62,7 +63,7 @@ from mautrix.util.opt_prometheus import Gauge
from . import portal as po, puppet as pu, util
from .abstract_user import AbstractUser
from .db import Backfill, Message as DBMessage, PgSession, User as DBUser
from .db import Backfill, BackfillType, Message as DBMessage, PgSession, User as DBUser
from .tgclient import MautrixTelegramClient
from .types import TelegramID
@@ -347,7 +348,7 @@ class User(DBUser, AbstractUser, BaseUser):
self._track_metric(METRIC_LOGGED_IN, True)
if not self._backfill_task or self._backfill_task.done():
self._backfill_task = asyncio.create_task(self._handle_backfill_requests_loop())
self._backfill_task = asyncio.create_task(self._try_handle_backfill_requests_loop())
try:
puppet = await pu.Puppet.get_by_tgid(self.tgid)
@@ -378,6 +379,14 @@ class User(DBUser, AbstractUser, BaseUser):
"max_file_size": min(self.bridge.matrix.media_config.upload_size, 2000 * 1024 * 1024),
}
async def _try_handle_backfill_requests_loop(self) -> None:
    """Entry-point wrapper for the backfill request loop.

    Does nothing when backfilling is disabled in the config, and logs any
    crash of the loop instead of letting it disappear silently.
    """
    if self.config["bridge.backfill.enable"]:
        try:
            await self._handle_backfill_requests_loop()
        except Exception:
            self.log.exception("Fatal error in backfill request loop")
async def _handle_backfill_requests_loop(self) -> None:
while True:
req = await Backfill.get_next(self.mxid)
@@ -388,7 +397,11 @@ class User(DBUser, AbstractUser, BaseUser):
pass
self.wakeup_backfill_task.clear()
else:
await self._takeout_and_backfill(req)
try:
await self._takeout_and_backfill(req)
except Exception:
self.log.exception("Error in takeout backfill loop, retrying in an hour")
await asyncio.sleep(3600)
async def _takeout_and_backfill(self, first_req: Backfill, first_attempt: bool = True) -> None:
self.takeout_retry_immediate.clear()
@@ -437,13 +450,33 @@ class User(DBUser, AbstractUser, BaseUser):
TelegramID(req.portal_tgid), tg_receiver=TelegramID(req.portal_tg_receiver)
)
await req.mark_dispatched()
await portal.backfill(self, client, req=req)
if req.type == BackfillType.HISTORICAL:
await portal.backfill(self, client, req=req)
elif req.type == BackfillType.SYNC_DIALOG:
await self._backfill_sync_dialog(portal, client, req.extra_data)
await req.mark_done()
await asyncio.sleep(req.post_batch_delay)
except Exception:
self.log.exception("Error handling backfill request for %s", req.portal_tgid)
await req.set_cooldown_timeout(1800)
async def _backfill_sync_dialog(
    self, portal: po.Portal, client: MautrixTelegramClient, post_sync_args: dict[str, Any]
) -> None:
    """Create the Matrix room for a deferred dialog-sync queue entry.

    No-ops when the portal already has a room. After a successful room
    creation, runs the dialog post-sync steps with the arguments that were
    stored in the queue entry.
    """
    if portal.mxid:
        # Room already exists; nothing to create.
        self.log.debug("Portal already exists, skipping dialog sync backfill queue item")
        return
    self.log.info(f"Creating portal for {portal.tgid_log} as part of backfill loop")
    created = False
    try:
        await portal.create_matrix_room(
            self, client=client, update_if_exists=False, invites=[self.mxid]
        )
        created = True
    except Exception:
        self.log.exception(f"Error while creating {portal.tgid_log}")
    if created:
        puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
        await self._post_sync_dialog(portal, puppet, was_created=True, **post_sync_args)
async def update(self, update: TypeUpdate) -> bool:
if not self.is_bot:
return False
@@ -604,19 +637,18 @@ class User(DBUser, AbstractUser, BaseUser):
if active and tag_info is None:
tag_info = RoomTagInfo(order=0.5)
tag_info[DOUBLE_PUPPET_SOURCE_KEY] = self.bridge.name
self.log.debug("Adding tag {tag} to {portal.mxid}/{portal.tgid}")
self.log.debug(f"Adding tag {tag} to {portal.mxid}/{portal.tgid}")
await puppet.intent.set_room_tag(portal.mxid, tag, tag_info)
elif (
not active and tag_info and tag_info.get(DOUBLE_PUPPET_SOURCE_KEY) == self.bridge.name
):
self.log.debug("Removing tag {tag} from {portal.mxid}/{portal.tgid}")
self.log.debug(f"Removing tag {tag} from {portal.mxid}/{portal.tgid}")
await puppet.intent.remove_room_tag(portal.mxid, tag)
async def _mute_room(self, puppet: pu.Puppet, portal: po.Portal, mute_until: datetime) -> None:
async def _mute_room(self, puppet: pu.Puppet, portal: po.Portal, mute_until: float) -> None:
if not self.config["bridge.mute_bridging"] or not portal or not portal.mxid:
return
now = datetime.utcnow().replace(tzinfo=timezone.utc)
if mute_until is not None and mute_until > now:
if mute_until is not None and mute_until > time.time():
self.log.debug(
f"Muting {portal.mxid}/{portal.tgid} (muted until {mute_until} on Telegram)"
)
@@ -672,12 +704,24 @@ class User(DBUser, AbstractUser, BaseUser):
portal = await po.Portal.get_by_entity(
update.peer.peer, tg_receiver=self.tgid, create=False
)
await self._mute_room(puppet, portal, update.notify_settings.mute_until)
await self._mute_room(puppet, portal, update.notify_settings.mute_until.timestamp())
async def _sync_dialog(
self, portal: po.Portal, dialog: Dialog, should_create: bool, puppet: pu.Puppet | None
) -> None:
was_created = False
post_sync_args = {
"last_message_ts": cast(datetime, dialog.date).timestamp(),
"unread_count": dialog.unread_count,
"max_read_id": dialog.dialog.read_inbox_max_id,
"mute_until": (
dialog.dialog.notify_settings.mute_until.timestamp()
if dialog.dialog.notify_settings.mute_until
else None
),
"pinned": dialog.pinned,
"archived": dialog.archived,
}
if portal.mxid:
try:
await portal.forward_backfill(self, initial=False, last_tgid=dialog.message.id)
@@ -693,41 +737,65 @@ class User(DBUser, AbstractUser, BaseUser):
was_created = True
except Exception:
self.log.exception(f"Error while creating {portal.tgid_log}")
if portal.mxid and puppet and puppet.is_real_user:
tg_space = portal.tgid if portal.peer_type == "channel" else self.tgid
last_message_date: float = cast(datetime, dialog.date).timestamp()
unread_threshold_hours = self.config["bridge.backfill.unread_hours_threshold"]
force_read = (
was_created
and unread_threshold_hours >= 0
and last_message_date + (unread_threshold_hours * 60 * 60) < time.time()
elif self.config["bridge.sync_deferred_create_all"]:
await portal.enqueue_backfill(
self,
priority=40,
type=BackfillType.SYNC_DIALOG,
extra_data=post_sync_args,
)
if dialog.unread_count == 0 or force_read:
# This is usually more reliable than finding a specific message
# e.g. if the last read message is a service message that isn't in the message db
last_read = await DBMessage.find_last(portal.mxid, tg_space)
if force_read:
self.log.debug(
f"Marking {portal.tgid_log} as read because the last message is from "
f"{dialog.date} (unread threshold is {unread_threshold_hours} hours)"
)
else:
last_read = await DBMessage.get_one_by_tgid(
portal.tgid, tg_space, dialog.dialog.read_inbox_max_id
if portal.mxid and puppet and puppet.is_real_user:
await self._post_sync_dialog(
portal=portal,
puppet=puppet,
was_created=was_created,
**post_sync_args,
)
async def _post_sync_dialog(
    self,
    portal: po.Portal,
    puppet: pu.Puppet,
    was_created: bool,
    max_read_id: int,
    last_message_ts: float,
    unread_count: int,
    # None means the chat is not muted (see the mute_until handling in the caller).
    mute_until: float | None,
    pinned: bool,
    archived: bool,
) -> None:
    """Apply read status, mute state and room tags to a portal after a dialog sync.

    Called either directly from _sync_dialog or later from the backfill queue
    (BackfillType.SYNC_DIALOG) with the same stored arguments.
    """
    self.log.debug(
        f"Running dialog post-sync for {portal.tgid_log} with args "
        f"{was_created=}, {max_read_id=}, {last_message_ts=}, {unread_count=}, "
        f"{mute_until=}, {pinned=}, {archived=}"
    )
    # Channels have their own message ID space; other chats use the user's space.
    tg_space = portal.tgid if portal.peer_type == "channel" else self.tgid
    unread_threshold_hours = self.config["bridge.backfill.unread_hours_threshold"]
    # Freshly created portals whose last message is older than the configured
    # threshold get marked fully read regardless of the Telegram unread count.
    force_read = (
        was_created
        and unread_threshold_hours >= 0
        and last_message_ts + (unread_threshold_hours * 60 * 60) < time.time()
    )
    if unread_count == 0 or force_read:
        # This is usually more reliable than finding a specific message
        # e.g. if the last read message is a service message that isn't in the message db
        last_read = await DBMessage.find_last(portal.mxid, tg_space)
        if force_read:
            self.log.debug(
                f"Marking {portal.tgid_log} as read because the last message is from "
                f"{last_message_ts} (unread threshold is {unread_threshold_hours} hours)"
            )
    else:
        last_read = await DBMessage.get_one_by_tgid(portal.tgid, tg_space, max_read_id)
    try:
        if last_read:
            await puppet.intent.mark_read(last_read.mx_room, last_read.mxid)
        if was_created or not self.config["bridge.tag_only_on_create"]:
            await self._mute_room(puppet, portal, mute_until)
            await self._tag_room(puppet, portal, self.config["bridge.pinned_tag"], pinned)
            await self._tag_room(puppet, portal, self.config["bridge.archive_tag"], archived)
    except Exception:
        self.log.exception(f"Error updating read status and tags for {portal.tgid_log}")
async def get_cached_portals(self) -> dict[tuple[TelegramID, TelegramID], po.Portal]:
if self._portals_cache is None:
@@ -744,9 +812,7 @@ class User(DBUser, AbstractUser, BaseUser):
update_limit = self.config["bridge.sync_update_limit"] or None
create_limit = self.config["bridge.sync_create_limit"]
index = 0
self.log.debug(
f"Syncing dialogs (update_limit={update_limit}, create_limit={create_limit})"
)
self.log.debug(f"Syncing dialogs ({update_limit=}, {create_limit=})")
await self.push_bridge_state(BridgeStateEvent.BACKFILLING)
puppet = await pu.Puppet.get_by_custom_mxid(self.mxid)
dialog: Dialog
@@ -767,11 +833,12 @@ class User(DBUser, AbstractUser, BaseUser):
continue
portal = await po.Portal.get_by_entity(entity, tg_receiver=self.tgid)
new_portal_cache[portal.tgid_full] = portal
should_create = not create_limit or index < create_limit
coro = self._sync_dialog(
portal=portal,
dialog=dialog,
puppet=puppet,
should_create=not create_limit or index < create_limit,
should_create=should_create,
)
creators.append(asyncio.create_task(coro))
index += 1