# mautrix-telegram - A Matrix-Telegram puppeting bridge # Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from __future__ import annotations from typing import AsyncGenerator, Awaitable, Union, cast from collections import defaultdict import asyncio import hashlib import logging import math import time from aiohttp import ClientResponse from telethon import helpers, utils from telethon.crypto import AuthKey from telethon.network import MTProtoSender from telethon.tl.alltlobjects import LAYER from telethon.tl.functions import InvokeWithLayerRequest from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest from telethon.tl.functions.upload import ( GetFileRequest, SaveBigFilePartRequest, SaveFilePartRequest, ) from telethon.tl.types import ( Document, InputDocumentFileLocation, InputFile, InputFileBig, InputFileLocation, InputPeerPhotoFileLocation, InputPhotoFileLocation, TypeInputFile, ) from mautrix.appservice import IntentAPI from mautrix.types import ContentURI, EncryptedFile from mautrix.util.logging import TraceLogger from ..db import TelegramFile as DBTelegramFile from ..tgclient import MautrixTelegramClient try: from mautrix.crypto.attachments import async_encrypt_attachment except ImportError: async_encrypt_attachment = None log: TraceLogger = cast(TraceLogger, logging.getLogger("mau.util")) TypeLocation = Union[ Document, InputDocumentFileLocation, InputPeerPhotoFileLocation, InputFileLocation, InputPhotoFileLocation, ] class DownloadSender: sender: MTProtoSender request: GetFileRequest remaining: int stride: int def __init__( self, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int, stride: int, count: int, ) -> None: self.sender = sender self.request = GetFileRequest(file, offset=offset, limit=limit) self.stride = stride self.remaining = count async def next(self) -> bytes | None: if not self.remaining: return None result = await self.sender.send(self.request) self.remaining -= 1 self.request.offset += self.stride return result.bytes def disconnect(self) -> Awaitable[None]: return self.sender.disconnect() class UploadSender: sender: MTProtoSender request: SaveFilePartRequest < SaveBigFilePartRequest part_count: int stride: int previous: asyncio.Task | None loop: asyncio.AbstractEventLoop def __init__( self, sender: MTProtoSender, file_id: int, part_count: int, big: bool, index: int, stride: int, loop: asyncio.AbstractEventLoop, ) -> None: self.sender = sender self.part_count = part_count if big: self.request = SaveBigFilePartRequest(file_id, index, part_count, b"") else: self.request = SaveFilePartRequest(file_id, index, b"") self.stride = stride self.previous = None self.loop = loop async def next(self, data: bytes) -> None: if self.previous: await self.previous self.previous = asyncio.create_task(self._next(data)) async def _next(self, data: bytes) -> None: self.request.bytes = data log.trace( f"Sending file part {self.request.file_part}/{self.part_count}" f" with {len(data)} bytes" ) await self.sender.send(self.request) self.request.file_part += self.stride async def disconnect(self) -> None: if self.previous: await self.previous return await self.sender.disconnect() class ParallelTransferrer: client: MautrixTelegramClient loop: asyncio.AbstractEventLoop dc_id: int senders: list[DownloadSender | UploadSender] | None auth_key: AuthKey upload_ticker: int def __init__(self, client: MautrixTelegramClient, dc_id: int | None = None) -> None: self.client = client self.loop = self.client.loop self.dc_id = dc_id or self.client.session.dc_id self.auth_key = ( None if dc_id and self.client.session.dc_id != dc_id else self.client.session.auth_key ) self.senders = None self.upload_ticker = 0 async def _cleanup(self) -> None: await asyncio.gather(*(sender.disconnect() for sender in self.senders)) self.senders = None @staticmethod def _get_connection_count( file_size: int, max_count: int = 20, full_size: int = 100 * 1024 * 1024 ) -> int: if file_size > full_size: return max_count return math.ceil((file_size / full_size) * max_count) async def _init_download( self, connections: int, file: TypeLocation, part_count: int, part_size: int ) -> None: minimum, remainder = divmod(part_count, connections) def get_part_count() -> int: nonlocal remainder if remainder > 0: remainder -= 1 return minimum + 1 return minimum # The first cross-DC sender will export+import the authorization, so we always create it # before creating any other senders. self.senders = [ await self._create_download_sender( file, 0, part_size, connections * part_size, get_part_count() ), *await asyncio.gather( *( self._create_download_sender( file, i, part_size, connections * part_size, get_part_count() ) for i in range(1, connections) ) ), ] async def _create_download_sender( self, file: TypeLocation, index: int, part_size: int, stride: int, part_count: int ) -> DownloadSender: return DownloadSender( await self._create_sender(), file, index * part_size, part_size, stride, part_count ) async def _init_upload( self, connections: int, file_id: int, part_count: int, big: bool ) -> None: self.senders = [ await self._create_upload_sender(file_id, part_count, big, 0, connections), *await asyncio.gather( *( self._create_upload_sender(file_id, part_count, big, i, connections) for i in range(1, connections) ) ), ] async def _create_upload_sender( self, file_id: int, part_count: int, big: bool, index: int, stride: int ) -> UploadSender: return UploadSender( await self._create_sender(), file_id, part_count, big, index, stride, loop=self.loop ) async def _create_sender(self) -> MTProtoSender: dc = await self.client._get_dc(self.dc_id) sender = MTProtoSender(self.auth_key, loggers=self.client._log) await sender.connect( self.client._connection( dc.ip_address, dc.port, dc.id, loggers=self.client._log, proxy=self.client._proxy ) ) if not self.auth_key: log.debug(f"Exporting auth to DC {self.dc_id}") auth = await self.client(ExportAuthorizationRequest(self.dc_id)) self.client._init_request.query = ImportAuthorizationRequest( id=auth.id, bytes=auth.bytes ) req = InvokeWithLayerRequest(LAYER, self.client._init_request) await sender.send(req) self.auth_key = sender.auth_key return sender async def init_upload( self, file_id: int, file_size: int, part_size_kb: float | None = None, connection_count: int | None = None, ) -> tuple[int, int, bool]: connection_count = connection_count or self._get_connection_count(file_size) part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024 part_count = (file_size + part_size - 1) // part_size is_large = file_size > 10 * 1024 * 1024 await self._init_upload(connection_count, file_id, part_count, is_large) return part_size, part_count, is_large async def upload(self, part: bytes) -> None: await self.senders[self.upload_ticker].next(part) self.upload_ticker = (self.upload_ticker + 1) % len(self.senders) async def finish_upload(self) -> None: await self._cleanup() async def download( self, file: TypeLocation, file_size: int, part_size_kb: float | None = None, connection_count: int | None = None, ) -> AsyncGenerator[bytes, None]: connection_count = connection_count or self._get_connection_count(file_size) part_size = (part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024 part_count = math.ceil(file_size / part_size) log.debug( f"Starting parallel download: {connection_count} {part_size} {part_count} {file!s}" ) await self._init_download(connection_count, file, part_count, part_size) part = 0 while part < part_count: tasks = [] for sender in self.senders: tasks.append(asyncio.create_task(sender.next())) for task in tasks: data = await task if not data: break yield data part += 1 log.trace(f"Part {part} downloaded") log.debug("Parallel download finished, cleaning up connections") await self._cleanup() parallel_transfer_locks: defaultdict[int, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) async def parallel_transfer_to_matrix( client: MautrixTelegramClient, intent: IntentAPI, loc_id: str, location: TypeLocation, filename: str, encrypt: bool, parallel_id: int, ) -> DBTelegramFile: size = location.size mime_type = location.mime_type dc_id, location = utils.get_input_location(location) # We lock the transfers because telegram has connection count limits async with parallel_transfer_locks[parallel_id]: downloader = ParallelTransferrer(client, dc_id) data = downloader.download(location, size) decryption_info = None up_mime_type = mime_type if encrypt and async_encrypt_attachment: async def encrypted(stream): nonlocal decryption_info async for chunk in async_encrypt_attachment(stream): if isinstance(chunk, EncryptedFile): decryption_info = chunk else: yield chunk data = encrypted(data) up_mime_type = "application/octet-stream" content_uri = await intent.upload_media( data, mime_type=up_mime_type, filename=filename, size=size if not encrypt else None ) if decryption_info: decryption_info.url = content_uri return DBTelegramFile( id=loc_id, mxc=content_uri, mime_type=mime_type, was_converted=False, timestamp=int(time.time()), size=size, width=None, height=None, decryption_info=decryption_info, ) async def _internal_transfer_to_telegram( client: MautrixTelegramClient, response: ClientResponse ) -> tuple[TypeInputFile, int]: file_id = helpers.generate_random_long() file_size = response.content_length hash_md5 = hashlib.md5() uploader = ParallelTransferrer(client) part_size, part_count, is_large = await uploader.init_upload(file_id, file_size) buffer = bytearray() async for data in response.content: if not is_large: hash_md5.update(data) if len(buffer) == 0 and len(data) == part_size: await uploader.upload(data) continue new_len = len(buffer) + len(data) if new_len >= part_size: cutoff = part_size - len(buffer) buffer.extend(data[:cutoff]) await uploader.upload(bytes(buffer)) buffer.clear() buffer.extend(data[cutoff:]) else: buffer.extend(data) if len(buffer) > 0: await uploader.upload(bytes(buffer)) await uploader.finish_upload() if is_large: return InputFileBig(file_id, part_count, "upload"), file_size else: return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size async def parallel_transfer_to_telegram( client: MautrixTelegramClient, intent: IntentAPI, uri: ContentURI, parallel_id: int ) -> tuple[TypeInputFile, int]: url = intent.api.get_download_url(uri) async with parallel_transfer_locks[parallel_id]: async with intent.api.session.get(url) as response: return await _internal_transfer_to_telegram(client, response)