gotd: add time synchronization

This commit is contained in:
Tulir Asokan
2026-02-26 18:24:48 +02:00
parent 93fe3cb0ea
commit 6af986ded5
15 changed files with 110 additions and 48 deletions

View File

@@ -2,6 +2,7 @@ package exchange
import (
"io"
"time"
"go.uber.org/zap"
@@ -23,4 +24,6 @@ type ClientExchangeResult struct {
AuthKey crypto.AuthKey
SessionID int64
ServerSalt int64
ServerTimeOffset time.Duration
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"math/big"
"time"
"github.com/go-faster/errors"
"go.uber.org/zap"
@@ -118,6 +119,7 @@ Loop:
EncryptedData: encryptedData,
}
c.log.Debug("Sending ReqDHParamsRequest")
reqStart := c.clock.Now()
if err := c.writeUnencrypted(ctx, b, reqDHParams); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "write ReqDHParamsRequest")
}
@@ -126,6 +128,7 @@ Loop:
if err := c.conn.Recv(ctx, b); err != nil {
return ClientExchangeResult{}, errors.Wrap(err, "read ServerDHParams message")
}
roundtripDuration := c.clock.Now().Sub(reqStart)
c.log.Debug("Received server ServerDHParams")
var plaintextMsg proto.UnencryptedMessage
@@ -262,6 +265,8 @@ Loop:
AuthKey: crypto.AuthKey{Value: key, ID: authKeyID},
SessionID: sessionID,
ServerSalt: serverSalt,
ServerTimeOffset: time.Unix(int64(innerData.ServerTime), 0).Sub(reqStart.Add(roundtripDuration / 2)),
}, nil
case *mt.DhGenRetry: // dh_gen_retry#46dc1fb9
return ClientExchangeResult{}, errors.Errorf("retry required: %x", v.NewNonceHash2)

View File

@@ -31,6 +31,7 @@ type Handler interface {
// MessageIDSource is message id generator.
type MessageIDSource interface {
New(t proto.MessageType) int64
Reset()
}
// MessageBuf is message id buffer.
@@ -73,6 +74,8 @@ type Conn struct {
salt int64
sessionID int64
serverTimeOffset time.Duration
// server salts fetched by getSalts.
salts salts.Salts
@@ -127,7 +130,6 @@ func New(dialer Dialer, opt Options) *Conn {
rand: opt.Random,
cipher: opt.Cipher,
log: opt.Logger,
messageID: opt.MessageID,
messageIDBuf: proto.NewMessageIDBuf(100),
ackSendChan: make(chan int64),
@@ -155,6 +157,7 @@ func New(dialer Dialer, opt Options) *Conn {
saltFetchInterval: opt.SaltFetchInterval,
getTimeout: opt.RequestTimeout,
}
conn.messageID = proto.NewMessageIDGen(conn.TimeWithOffset)
if conn.rpc == nil {
conn.rpc = rpc.New(conn.writeContentMessage, rpc.Options{
Logger: opt.Logger.Named("rpc"),
@@ -218,3 +221,40 @@ func (c *Conn) Run(ctx context.Context, f func(ctx context.Context) error) error
}
return nil
}
func (c *Conn) setServerTimeOffset(offset time.Duration) {
if offset == 0 {
offset = 1
}
c.sessionMux.Lock()
c.serverTimeOffset = offset
c.sessionMux.Unlock()
if offset > 10*time.Second || offset < -10*time.Second {
c.log.Warn("Updated server time offset (high)", zap.Duration("offset", offset))
} else {
c.log.Info("Updated server time offset", zap.Duration("offset", offset))
}
}
func (c *Conn) hasServerTimeOffset() bool {
c.sessionMux.RLock()
has := c.serverTimeOffset != 0
c.sessionMux.RUnlock()
return has
}
func (c *Conn) TimeWithOffset() (t time.Time) {
c.sessionMux.RLock()
t = c.clock.Now().Add(c.serverTimeOffset)
c.sessionMux.RUnlock()
return
}
func (c *Conn) altTimeWithOffset() (t time.Time) {
c.sessionMux.RLock()
if c.serverTimeOffset != 0 {
t = c.clock.Now().Add(c.serverTimeOffset)
}
c.sessionMux.RUnlock()
return
}

View File

@@ -9,7 +9,6 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
@@ -44,10 +43,9 @@ func newTestClient(h testHandler, opts ...testClientOption) *Conn {
}, rpc.Options{})
opt := Options{
Logger: zap.NewNop(),
Random: rand.New(rand.NewSource(1)),
Key: crypto.Key{}.WithID(),
MessageID: proto.NewMessageIDGen(time.Now),
Logger: zap.NewNop(),
Random: rand.New(rand.NewSource(1)),
Key: crypto.Key{}.WithID(),
engine: engine,
}

View File

@@ -84,10 +84,17 @@ func (c *Conn) createAuthKey(ctx context.Context) error {
}
c.sessionMux.Lock()
c.serverTimeOffset = r.ServerTimeOffset
if c.serverTimeOffset == 0 {
// Creating an auth key always calculates the offset and it should never be 0 in practice,
// but default to 1 just in case
c.serverTimeOffset = 1
}
c.authKey = r.AuthKey
c.sessionID = r.SessionID
c.salt = r.ServerSalt
c.sessionMux.Unlock()
c.log.Info("Created auth key", zap.Duration("server_time_offset", r.ServerTimeOffset))
return nil
}

View File

@@ -7,11 +7,13 @@ import (
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
)
type badMessageError struct {
Code int
NewSalt int64
Code int
NewSalt int64
TimeResynced bool
}
const (
@@ -40,7 +42,8 @@ func (c badMessageError) Error() string {
return description
}
func (c *Conn) handleBadMsg(b *bin.Buffer) error {
func (c *Conn) handleBadMsg(msgID int64, b *bin.Buffer) error {
now := c.clock.Now()
id, err := b.PeekID()
if err != nil {
return err
@@ -51,8 +54,16 @@ func (c *Conn) handleBadMsg(b *bin.Buffer) error {
if err := bad.Decode(b); err != nil {
return err
}
var resynced bool
if !c.hasServerTimeOffset() && (bad.ErrorCode == codeMessageIDTooLow || bad.ErrorCode == codeMessageIDTooHigh) {
created := proto.MessageID(msgID).Time()
c.setServerTimeOffset(created.Sub(now))
c.messageID.Reset()
c.updateSalt()
resynced = true
}
c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode})
c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode, TimeResynced: resynced})
return nil
case mt.BadServerSaltTypeID:
var bad mt.BadServerSalt

View File

@@ -19,7 +19,7 @@ func (c *Conn) handleMessage(msgID int64, b *bin.Buffer) error {
case mt.NewSessionCreatedTypeID:
return c.handleSessionCreated(b)
case mt.BadMsgNotificationTypeID, mt.BadServerSaltTypeID:
return c.handleBadMsg(b)
return c.handleBadMsg(msgID, b)
case mt.FutureSaltsTypeID:
return c.handleFutureSalts(b)
case proto.MessageContainerTypeID:

View File

@@ -24,14 +24,7 @@ func (c *Conn) handleSessionCreated(b *bin.Buffer) error {
zap.Time("first_msg_time", created.Local()),
)
if (created.Before(now) && now.Sub(created) > maxPast) || created.Sub(now) > maxFuture {
c.log.Warn("Local clock needs synchronization",
zap.Time("first_msg_time", created),
zap.Time("local", now),
zap.Duration("time_difference", now.Sub(created)),
)
}
c.setServerTimeOffset(created.Sub(now))
c.storeSalt(s.ServerSalt)
if err := c.handler.OnSession(c.session()); err != nil {
return errors.Wrap(err, "handler.OnSession")

View File

@@ -11,7 +11,6 @@ import (
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
)
@@ -64,9 +63,6 @@ type Options struct {
// If < 0, compression will be disabled.
// If == 0, default value will be used.
CompressThreshold int
// MessageID is message id source. Share source between connection to
// reduce collision probability.
MessageID MessageIDSource
// Clock is current time source. Defaults to system time.
Clock clock.Clock
// Types map, used in verbose logging of incoming message.
@@ -152,9 +148,6 @@ func (opt *Options) setDefaults() {
if opt.Clock == nil {
opt.Clock = clock.System
}
if opt.MessageID == nil {
opt.MessageID = proto.NewMessageIDGen(opt.Clock.Now)
}
if len(opt.PublicKeys) == 0 {
opt.setDefaultPublicKeys()
}

View File

@@ -38,12 +38,14 @@ func checkMessageID(now time.Time, rawID int64) error {
return errors.Wrapf(errRejected, "unexpected type %s", id.Type())
}
created := id.Time()
if created.Before(now) && now.Sub(created) > maxPast {
return errors.Wrap(errRejected, "created too far in past")
}
if created.Sub(now) > maxFuture {
return errors.Wrap(errRejected, "created too far in future")
if !now.IsZero() {
created := id.Time()
if created.Before(now) && now.Sub(created) > maxPast {
return errors.Wrap(errRejected, "created too far in past")
}
if created.Sub(now) > maxFuture {
return errors.Wrap(errRejected, "created too far in future")
}
}
return nil
@@ -60,7 +62,8 @@ func (c *Conn) decryptMessage(b *bin.Buffer) (*crypto.EncryptedMessageData, erro
if msg.SessionID != session.ID {
return nil, errors.Wrapf(errRejected, "invalid session (got %d, expected %d)", msg.SessionID, session.ID)
}
if err := checkMessageID(c.clock.Now(), msg.MessageID); err != nil {
if err := checkMessageID(c.altTimeWithOffset(), msg.MessageID); err != nil {
return nil, errors.Wrapf(err, "bad message id %d", msg.MessageID)
}
if !c.messageIDBuf.Consume(msg.MessageID) {

View File

@@ -8,6 +8,7 @@ import (
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
)
@@ -23,20 +24,29 @@ func (c *Conn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder
Output: output,
}
if err := c.rpc.Do(ctx, req); err != nil {
for retries := 0; ; retries++ {
var badMsgErr *badMessageError
if errors.As(err, &badMsgErr) && badMsgErr.Code == codeIncorrectServerSalt {
err := c.rpc.Do(ctx, req)
if err == nil || retries >= 2 || !errors.As(err, &badMsgErr) {
return err
} else if badMsgErr.Code == codeIncorrectServerSalt {
// Store salt from server.
c.storeSalt(badMsgErr.NewSalt)
// Reset saved salts to fetch new.
c.salts.Reset()
c.log.Info("Retrying request after updating salt from badMsgErr", zap.Int64("msg_id", req.MsgID))
return c.rpc.Do(ctx, req)
} else if badMsgErr.TimeResynced {
req.MsgID, req.SeqNo = c.nextMsgSeq(true)
c.log.Info("Retrying request after adjusting time offset from badMsgErr",
zap.Int64("old_msg_id", msgID),
zap.Int64("new_msg_id", req.MsgID),
zap.Stringer("old_msg_id_str", proto.MessageID(msgID)),
zap.Stringer("new_msg_id_str", proto.MessageID(req.MsgID)),
)
} else {
return err
}
return err
}
return nil
}
func (c *Conn) dropRPC(req rpc.Request) error {

View File

@@ -23,7 +23,7 @@ func (c *Conn) storeSalt(salt int64) {
}
func (c *Conn) updateSalt() {
salt, ok := c.salts.Get(c.clock.Now().Add(time.Minute * 5))
salt, ok := c.salts.Get(c.TimeWithOffset().Add(time.Minute * 5))
if !ok {
return
}

View File

@@ -152,6 +152,12 @@ func (g *MessageIDGen) New(t MessageType) int64 {
return int64(NewMessageIDNano(g.nano, t))
}
func (g *MessageIDGen) Reset() {
g.mux.Lock()
g.nano = 0
g.mux.Unlock()
}
// NewMessageIDGen creates new message id generator.
//
// Current time will be provided by now() function.

View File

@@ -206,7 +206,6 @@ func NewClient(appID int, appHash string, opt Options) *Client {
RetryInterval: opt.RetryInterval,
MaxRetries: opt.MaxRetries,
CompressThreshold: opt.CompressThreshold,
MessageID: opt.MessageID,
ExchangeTimeout: opt.ExchangeTimeout,
DialTimeout: opt.DialTimeout,
Clock: opt.Clock,

View File

@@ -12,8 +12,6 @@ import (
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
)
@@ -99,8 +97,7 @@ type Options struct {
// Will be sent with session creation request.
Device DeviceConfig
MessageID mtproto.MessageIDSource
Clock clock.Clock
Clock clock.Clock
PingInterval time.Duration
PingTimeout time.Duration
@@ -153,9 +150,6 @@ func (opt *Options) setDefaults() {
if opt.MigrationTimeout == 0 {
opt.MigrationTimeout = time.Second * 15
}
if opt.MessageID == nil {
opt.MessageID = proto.NewMessageIDGen(opt.Clock.Now)
}
if opt.UpdateHandler == nil {
// No updates handler passed, so no sense to subscribe for updates.
// User should explicitly ignore updates using custom UpdateHandler.