mirror of
https://github.com/mautrix/telegram.git
synced 2026-05-17 07:25:46 +03:00
gotd: add time synchronization
This commit is contained in:
@@ -2,6 +2,7 @@ package exchange
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
@@ -23,4 +24,6 @@ type ClientExchangeResult struct {
|
||||
AuthKey crypto.AuthKey
|
||||
SessionID int64
|
||||
ServerSalt int64
|
||||
|
||||
ServerTimeOffset time.Duration
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"go.uber.org/zap"
|
||||
@@ -118,6 +119,7 @@ Loop:
|
||||
EncryptedData: encryptedData,
|
||||
}
|
||||
c.log.Debug("Sending ReqDHParamsRequest")
|
||||
reqStart := c.clock.Now()
|
||||
if err := c.writeUnencrypted(ctx, b, reqDHParams); err != nil {
|
||||
return ClientExchangeResult{}, errors.Wrap(err, "write ReqDHParamsRequest")
|
||||
}
|
||||
@@ -126,6 +128,7 @@ Loop:
|
||||
if err := c.conn.Recv(ctx, b); err != nil {
|
||||
return ClientExchangeResult{}, errors.Wrap(err, "read ServerDHParams message")
|
||||
}
|
||||
roundtripDuration := c.clock.Now().Sub(reqStart)
|
||||
c.log.Debug("Received server ServerDHParams")
|
||||
|
||||
var plaintextMsg proto.UnencryptedMessage
|
||||
@@ -262,6 +265,8 @@ Loop:
|
||||
AuthKey: crypto.AuthKey{Value: key, ID: authKeyID},
|
||||
SessionID: sessionID,
|
||||
ServerSalt: serverSalt,
|
||||
|
||||
ServerTimeOffset: time.Unix(int64(innerData.ServerTime), 0).Sub(reqStart.Add(roundtripDuration / 2)),
|
||||
}, nil
|
||||
case *mt.DhGenRetry: // dh_gen_retry#46dc1fb9
|
||||
return ClientExchangeResult{}, errors.Errorf("retry required: %x", v.NewNonceHash2)
|
||||
|
||||
@@ -31,6 +31,7 @@ type Handler interface {
|
||||
// MessageIDSource is message id generator.
|
||||
type MessageIDSource interface {
|
||||
New(t proto.MessageType) int64
|
||||
Reset()
|
||||
}
|
||||
|
||||
// MessageBuf is message id buffer.
|
||||
@@ -73,6 +74,8 @@ type Conn struct {
|
||||
salt int64
|
||||
sessionID int64
|
||||
|
||||
serverTimeOffset time.Duration
|
||||
|
||||
// server salts fetched by getSalts.
|
||||
salts salts.Salts
|
||||
|
||||
@@ -127,7 +130,6 @@ func New(dialer Dialer, opt Options) *Conn {
|
||||
rand: opt.Random,
|
||||
cipher: opt.Cipher,
|
||||
log: opt.Logger,
|
||||
messageID: opt.MessageID,
|
||||
messageIDBuf: proto.NewMessageIDBuf(100),
|
||||
|
||||
ackSendChan: make(chan int64),
|
||||
@@ -155,6 +157,7 @@ func New(dialer Dialer, opt Options) *Conn {
|
||||
saltFetchInterval: opt.SaltFetchInterval,
|
||||
getTimeout: opt.RequestTimeout,
|
||||
}
|
||||
conn.messageID = proto.NewMessageIDGen(conn.TimeWithOffset)
|
||||
if conn.rpc == nil {
|
||||
conn.rpc = rpc.New(conn.writeContentMessage, rpc.Options{
|
||||
Logger: opt.Logger.Named("rpc"),
|
||||
@@ -218,3 +221,40 @@ func (c *Conn) Run(ctx context.Context, f func(ctx context.Context) error) error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) setServerTimeOffset(offset time.Duration) {
|
||||
if offset == 0 {
|
||||
offset = 1
|
||||
}
|
||||
c.sessionMux.Lock()
|
||||
c.serverTimeOffset = offset
|
||||
c.sessionMux.Unlock()
|
||||
if offset > 10*time.Second || offset < -10*time.Second {
|
||||
c.log.Warn("Updated server time offset (high)", zap.Duration("offset", offset))
|
||||
} else {
|
||||
c.log.Info("Updated server time offset", zap.Duration("offset", offset))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) hasServerTimeOffset() bool {
|
||||
c.sessionMux.RLock()
|
||||
has := c.serverTimeOffset != 0
|
||||
c.sessionMux.RUnlock()
|
||||
return has
|
||||
}
|
||||
|
||||
func (c *Conn) TimeWithOffset() (t time.Time) {
|
||||
c.sessionMux.RLock()
|
||||
t = c.clock.Now().Add(c.serverTimeOffset)
|
||||
c.sessionMux.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) altTimeWithOffset() (t time.Time) {
|
||||
c.sessionMux.RLock()
|
||||
if c.serverTimeOffset != 0 {
|
||||
t = c.clock.Now().Add(c.serverTimeOffset)
|
||||
}
|
||||
c.sessionMux.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
@@ -44,10 +43,9 @@ func newTestClient(h testHandler, opts ...testClientOption) *Conn {
|
||||
}, rpc.Options{})
|
||||
|
||||
opt := Options{
|
||||
Logger: zap.NewNop(),
|
||||
Random: rand.New(rand.NewSource(1)),
|
||||
Key: crypto.Key{}.WithID(),
|
||||
MessageID: proto.NewMessageIDGen(time.Now),
|
||||
Logger: zap.NewNop(),
|
||||
Random: rand.New(rand.NewSource(1)),
|
||||
Key: crypto.Key{}.WithID(),
|
||||
|
||||
engine: engine,
|
||||
}
|
||||
|
||||
@@ -84,10 +84,17 @@ func (c *Conn) createAuthKey(ctx context.Context) error {
|
||||
}
|
||||
|
||||
c.sessionMux.Lock()
|
||||
c.serverTimeOffset = r.ServerTimeOffset
|
||||
if c.serverTimeOffset == 0 {
|
||||
// Creating an auth key always calculates the offset and it should never be 0 in practice,
|
||||
// but default to 1 just in case
|
||||
c.serverTimeOffset = 1
|
||||
}
|
||||
c.authKey = r.AuthKey
|
||||
c.sessionID = r.SessionID
|
||||
c.salt = r.ServerSalt
|
||||
c.sessionMux.Unlock()
|
||||
c.log.Info("Created auth key", zap.Duration("server_time_offset", r.ServerTimeOffset))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,11 +7,13 @@ import (
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
)
|
||||
|
||||
type badMessageError struct {
|
||||
Code int
|
||||
NewSalt int64
|
||||
Code int
|
||||
NewSalt int64
|
||||
TimeResynced bool
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -40,7 +42,8 @@ func (c badMessageError) Error() string {
|
||||
return description
|
||||
}
|
||||
|
||||
func (c *Conn) handleBadMsg(b *bin.Buffer) error {
|
||||
func (c *Conn) handleBadMsg(msgID int64, b *bin.Buffer) error {
|
||||
now := c.clock.Now()
|
||||
id, err := b.PeekID()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -51,8 +54,16 @@ func (c *Conn) handleBadMsg(b *bin.Buffer) error {
|
||||
if err := bad.Decode(b); err != nil {
|
||||
return err
|
||||
}
|
||||
var resynced bool
|
||||
if !c.hasServerTimeOffset() && (bad.ErrorCode == codeMessageIDTooLow || bad.ErrorCode == codeMessageIDTooHigh) {
|
||||
created := proto.MessageID(msgID).Time()
|
||||
c.setServerTimeOffset(created.Sub(now))
|
||||
c.messageID.Reset()
|
||||
c.updateSalt()
|
||||
resynced = true
|
||||
}
|
||||
|
||||
c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode})
|
||||
c.rpc.NotifyError(bad.BadMsgID, &badMessageError{Code: bad.ErrorCode, TimeResynced: resynced})
|
||||
return nil
|
||||
case mt.BadServerSaltTypeID:
|
||||
var bad mt.BadServerSalt
|
||||
|
||||
@@ -19,7 +19,7 @@ func (c *Conn) handleMessage(msgID int64, b *bin.Buffer) error {
|
||||
case mt.NewSessionCreatedTypeID:
|
||||
return c.handleSessionCreated(b)
|
||||
case mt.BadMsgNotificationTypeID, mt.BadServerSaltTypeID:
|
||||
return c.handleBadMsg(b)
|
||||
return c.handleBadMsg(msgID, b)
|
||||
case mt.FutureSaltsTypeID:
|
||||
return c.handleFutureSalts(b)
|
||||
case proto.MessageContainerTypeID:
|
||||
|
||||
@@ -24,14 +24,7 @@ func (c *Conn) handleSessionCreated(b *bin.Buffer) error {
|
||||
zap.Time("first_msg_time", created.Local()),
|
||||
)
|
||||
|
||||
if (created.Before(now) && now.Sub(created) > maxPast) || created.Sub(now) > maxFuture {
|
||||
c.log.Warn("Local clock needs synchronization",
|
||||
zap.Time("first_msg_time", created),
|
||||
zap.Time("local", now),
|
||||
zap.Duration("time_difference", now.Sub(created)),
|
||||
)
|
||||
}
|
||||
|
||||
c.setServerTimeOffset(created.Sub(now))
|
||||
c.storeSalt(s.ServerSalt)
|
||||
if err := c.handler.OnSession(c.session()); err != nil {
|
||||
return errors.Wrap(err, "handler.OnSession")
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tmap"
|
||||
)
|
||||
@@ -64,9 +63,6 @@ type Options struct {
|
||||
// If < 0, compression will be disabled.
|
||||
// If == 0, default value will be used.
|
||||
CompressThreshold int
|
||||
// MessageID is message id source. Share source between connection to
|
||||
// reduce collision probability.
|
||||
MessageID MessageIDSource
|
||||
// Clock is current time source. Defaults to system time.
|
||||
Clock clock.Clock
|
||||
// Types map, used in verbose logging of incoming message.
|
||||
@@ -152,9 +148,6 @@ func (opt *Options) setDefaults() {
|
||||
if opt.Clock == nil {
|
||||
opt.Clock = clock.System
|
||||
}
|
||||
if opt.MessageID == nil {
|
||||
opt.MessageID = proto.NewMessageIDGen(opt.Clock.Now)
|
||||
}
|
||||
if len(opt.PublicKeys) == 0 {
|
||||
opt.setDefaultPublicKeys()
|
||||
}
|
||||
|
||||
@@ -38,12 +38,14 @@ func checkMessageID(now time.Time, rawID int64) error {
|
||||
return errors.Wrapf(errRejected, "unexpected type %s", id.Type())
|
||||
}
|
||||
|
||||
created := id.Time()
|
||||
if created.Before(now) && now.Sub(created) > maxPast {
|
||||
return errors.Wrap(errRejected, "created too far in past")
|
||||
}
|
||||
if created.Sub(now) > maxFuture {
|
||||
return errors.Wrap(errRejected, "created too far in future")
|
||||
if !now.IsZero() {
|
||||
created := id.Time()
|
||||
if created.Before(now) && now.Sub(created) > maxPast {
|
||||
return errors.Wrap(errRejected, "created too far in past")
|
||||
}
|
||||
if created.Sub(now) > maxFuture {
|
||||
return errors.Wrap(errRejected, "created too far in future")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -60,7 +62,8 @@ func (c *Conn) decryptMessage(b *bin.Buffer) (*crypto.EncryptedMessageData, erro
|
||||
if msg.SessionID != session.ID {
|
||||
return nil, errors.Wrapf(errRejected, "invalid session (got %d, expected %d)", msg.SessionID, session.ID)
|
||||
}
|
||||
if err := checkMessageID(c.clock.Now(), msg.MessageID); err != nil {
|
||||
|
||||
if err := checkMessageID(c.altTimeWithOffset(), msg.MessageID); err != nil {
|
||||
return nil, errors.Wrapf(err, "bad message id %d", msg.MessageID)
|
||||
}
|
||||
if !c.messageIDBuf.Consume(msg.MessageID) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/bin"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mt"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/rpc"
|
||||
)
|
||||
|
||||
@@ -23,20 +24,29 @@ func (c *Conn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder
|
||||
Output: output,
|
||||
}
|
||||
|
||||
if err := c.rpc.Do(ctx, req); err != nil {
|
||||
for retries := 0; ; retries++ {
|
||||
var badMsgErr *badMessageError
|
||||
if errors.As(err, &badMsgErr) && badMsgErr.Code == codeIncorrectServerSalt {
|
||||
err := c.rpc.Do(ctx, req)
|
||||
if err == nil || retries >= 2 || !errors.As(err, &badMsgErr) {
|
||||
return err
|
||||
} else if badMsgErr.Code == codeIncorrectServerSalt {
|
||||
// Store salt from server.
|
||||
c.storeSalt(badMsgErr.NewSalt)
|
||||
// Reset saved salts to fetch new.
|
||||
c.salts.Reset()
|
||||
c.log.Info("Retrying request after updating salt from badMsgErr", zap.Int64("msg_id", req.MsgID))
|
||||
return c.rpc.Do(ctx, req)
|
||||
} else if badMsgErr.TimeResynced {
|
||||
req.MsgID, req.SeqNo = c.nextMsgSeq(true)
|
||||
c.log.Info("Retrying request after adjusting time offset from badMsgErr",
|
||||
zap.Int64("old_msg_id", msgID),
|
||||
zap.Int64("new_msg_id", req.MsgID),
|
||||
zap.Stringer("old_msg_id_str", proto.MessageID(msgID)),
|
||||
zap.Stringer("new_msg_id_str", proto.MessageID(req.MsgID)),
|
||||
)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) dropRPC(req rpc.Request) error {
|
||||
|
||||
@@ -23,7 +23,7 @@ func (c *Conn) storeSalt(salt int64) {
|
||||
}
|
||||
|
||||
func (c *Conn) updateSalt() {
|
||||
salt, ok := c.salts.Get(c.clock.Now().Add(time.Minute * 5))
|
||||
salt, ok := c.salts.Get(c.TimeWithOffset().Add(time.Minute * 5))
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -152,6 +152,12 @@ func (g *MessageIDGen) New(t MessageType) int64 {
|
||||
return int64(NewMessageIDNano(g.nano, t))
|
||||
}
|
||||
|
||||
func (g *MessageIDGen) Reset() {
|
||||
g.mux.Lock()
|
||||
g.nano = 0
|
||||
g.mux.Unlock()
|
||||
}
|
||||
|
||||
// NewMessageIDGen creates new message id generator.
|
||||
//
|
||||
// Current time will be provided by now() function.
|
||||
|
||||
@@ -206,7 +206,6 @@ func NewClient(appID int, appHash string, opt Options) *Client {
|
||||
RetryInterval: opt.RetryInterval,
|
||||
MaxRetries: opt.MaxRetries,
|
||||
CompressThreshold: opt.CompressThreshold,
|
||||
MessageID: opt.MessageID,
|
||||
ExchangeTimeout: opt.ExchangeTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
Clock: opt.Clock,
|
||||
|
||||
@@ -12,8 +12,6 @@ import (
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/clock"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/crypto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/exchange"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/mtproto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/proto"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram/dcs"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
)
|
||||
@@ -99,8 +97,7 @@ type Options struct {
|
||||
// Will be sent with session creation request.
|
||||
Device DeviceConfig
|
||||
|
||||
MessageID mtproto.MessageIDSource
|
||||
Clock clock.Clock
|
||||
Clock clock.Clock
|
||||
|
||||
PingInterval time.Duration
|
||||
PingTimeout time.Duration
|
||||
@@ -153,9 +150,6 @@ func (opt *Options) setDefaults() {
|
||||
if opt.MigrationTimeout == 0 {
|
||||
opt.MigrationTimeout = time.Second * 15
|
||||
}
|
||||
if opt.MessageID == nil {
|
||||
opt.MessageID = proto.NewMessageIDGen(opt.Clock.Now)
|
||||
}
|
||||
if opt.UpdateHandler == nil {
|
||||
// No updates handler passed, so no sense to subscribe for updates.
|
||||
// User should explicitly ignore updates using custom UpdateHandler.
|
||||
|
||||
Reference in New Issue
Block a user