You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

218 lines
5.6 KiB

package mtproto
import (
"context"
"io"
"sync"
"time"
"github.com/go-faster/errors"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/gotd/td/bin"
"github.com/gotd/td/clock"
"github.com/gotd/td/internal/crypto"
"github.com/gotd/td/internal/exchange"
"github.com/gotd/td/internal/mtproto/salts"
"github.com/gotd/td/internal/proto"
"github.com/gotd/td/internal/rpc"
"github.com/gotd/td/internal/tdsync"
"github.com/gotd/td/internal/tmap"
"github.com/gotd/td/transport"
)
// Handler will be called on received message from Telegram.
//
// A Conn holds exactly one Handler, taken from Options.Handler in New.
type Handler interface {
// OnMessage is called with the buffer of a received message.
// NOTE(review): whether b may be retained after the call returns is not
// visible here — confirm against the caller.
OnMessage(b *bin.Buffer) error
// OnSession is called with the session once the connection is ready
// (as stated by the Conn.Run contract).
OnSession(session Session) error
}
// MessageIDSource is message id generator.
type MessageIDSource interface {
// New returns a message id for a message of type t.
New(t proto.MessageType) int64
}
// MessageBuf is message id buffer.
//
// Conn keeps one in its messageIDBuf field for replay attack protection.
type MessageBuf interface {
// Consume takes a message id and reports a boolean verdict for it.
// NOTE(review): presumably false means the id was already seen or is
// otherwise rejected — confirm against the implementation.
Consume(id int64) bool
}
// Cipher handles message encryption and decryption.
type Cipher interface {
// DecryptFromBuffer decrypts the message contained in buf using auth key k
// and returns the decrypted message data.
DecryptFromBuffer(k crypto.AuthKey, buf *bin.Buffer) (*crypto.EncryptedMessageData, error)
// Encrypt encrypts data using key.
// NOTE(review): presumably the encrypted result is written to b — confirm.
Encrypt(key crypto.AuthKey, data crypto.EncryptedMessageData, b *bin.Buffer) error
}
// Dialer is an abstraction for MTProto transport connection creator.
//
// It receives a context and returns a transport connection or an error.
type Dialer func(ctx context.Context) (transport.Conn, error)
// Conn represents a MTProto client to Telegram.
//
// Create it with New and start it with Run. Run may be called only once per
// Conn (guarded by the ran flag).
type Conn struct {
// dcID is set from Options.DC.
// NOTE(review): presumably the datacenter ID — confirm.
dcID int
// dialer creates the transport connection.
dialer Dialer
// conn is the underlying transport connection; handleClose closes it.
conn transport.Conn
// handler is the user-supplied Handler (Options.Handler).
handler Handler
// rpc is the RPC engine: the one passed through options, or one created
// by New when none was passed.
rpc *rpc.Engine
// rsaPublicKeys is set from Options.PublicKeys.
// NOTE(review): presumably used during key exchange — confirm.
rsaPublicKeys []exchange.PublicKey
// types is set from Options.Types.
types *tmap.Map
// Wrappers for external world, like current time, logs or PRNG.
// Should be immutable.
clock clock.Clock
rand io.Reader
cipher Cipher
log *zap.Logger
messageID MessageIDSource
messageIDBuf MessageBuf // replay attack protection
// use session() to access authKey, salt or sessionID.
sessionMux sync.RWMutex
authKey crypto.AuthKey
salt int64
sessionID int64
// server salts fetched by getSalts.
salts salts.Salts
// sentContentMessages is count of created content messages, used to
// compute sequence number within session.
sentContentMessages int32
// NOTE(review): what reqMux guards is not visible in this part of the
// file — presumably sentContentMessages; confirm.
reqMux sync.Mutex
// ackSendChan is queue for outgoing message IDs that require waiting for
// ack from server. New creates it unbuffered.
ackSendChan chan int64
// ackBatchSize and ackInterval are set from Options.AckBatchSize and
// Options.AckInterval.
ackBatchSize int
ackInterval time.Duration
// callbacks for ping results.
// Key is ping id.
ping map[int64]chan struct{}
// NOTE(review): pingMux presumably guards the ping map — confirm.
pingMux sync.Mutex
// pingTimeout sets ping_delay_disconnect delay.
pingTimeout time.Duration
// pingInterval is duration between ping_delay_disconnect requests.
pingInterval time.Duration
// gotSession is a signal used to wait for the handleSessionCreated message.
gotSession *tdsync.Ready
// exchangeLock locks write calls during key exchange.
exchangeLock sync.RWMutex
// compressThreshold is a threshold in bytes to determine that message
// is large enough to be compressed using gzip.
compressThreshold int
// dialTimeout, exchangeTimeout and saltFetchInterval are copied from the
// options of the same name.
dialTimeout time.Duration
exchangeTimeout time.Duration
saltFetchInterval time.Duration
// getTimeout is set from Options.RequestTimeout.
// NOTE(review): req is presumably the request type ID — confirm.
getTimeout func(req uint32) time.Duration
// Ensure Run once.
ran atomic.Bool
}
// New creates a new connection that stays unstarted until Run is called.
//
// Options left unset by the caller are filled in by opt.setDefaults. When the
// options carry no RPC engine, a fresh one is created that writes content
// messages through the returned Conn and reports dropped requests to it.
func New(dialer Dialer, opt Options) *Conn {
	// Fill in defaults for everything the caller did not set.
	opt.setDefaults()

	c := &Conn{
		dcID:          opt.DC,
		dialer:        dialer,
		handler:       opt.Handler,
		rpc:           opt.engine,
		rsaPublicKeys: opt.PublicKeys,
		types:         opt.Types,

		clock:        opt.Clock,
		rand:         opt.Random,
		cipher:       opt.Cipher,
		log:          opt.Logger,
		messageID:    opt.MessageID,
		messageIDBuf: proto.NewMessageIDBuf(100),

		authKey: opt.Key,
		salt:    opt.Salt,

		ackSendChan:  make(chan int64),
		ackBatchSize: opt.AckBatchSize,
		ackInterval:  opt.AckInterval,

		ping:         map[int64]chan struct{}{},
		pingTimeout:  opt.PingTimeout,
		pingInterval: opt.PingInterval,
		gotSession:   tdsync.NewReady(),

		compressThreshold: opt.CompressThreshold,
		dialTimeout:       opt.DialTimeout,
		exchangeTimeout:   opt.ExchangeTimeout,
		saltFetchInterval: opt.SaltFetchInterval,
		getTimeout:        opt.RequestTimeout,
	}

	// An engine supplied through options is used as is.
	if c.rpc != nil {
		return c
	}

	// Otherwise build an engine bound to this connection.
	engineOpts := rpc.Options{
		Logger:        opt.Logger.Named("rpc"),
		RetryInterval: opt.RetryInterval,
		MaxRetries:    opt.MaxRetries,
		Clock:         opt.Clock,
		DropHandler:   c.dropRPC,
	}
	c.rpc = rpc.New(c.writeContentMessage, engineOpts)

	return c
}
// handleClose waits until ctx is done, then force-closes the RPC engine and
// closes the underlying connection.
//
// A failure to close the connection is only logged at debug level; the
// returned error is always nil.
func (c *Conn) handleClose(ctx context.Context) error {
	// Block until the context is done.
	<-ctx.Done()
	c.log.Debug("Closing")

	// Shut down the RPC engine first, then the transport.
	c.rpc.ForceClose()

	closeErr := c.conn.Close()
	if closeErr != nil {
		c.log.Debug("Failed to cleanup connection", zap.Error(closeErr))
	}

	return nil
}
// Run initializes MTProto connection to server, starts the background loops
// together with the user callback f, and blocks until the group of goroutines
// finishes.
//
// Run may be called only once per Conn; any further call returns an error
// right away. When connection is ready, Handler.OnSession is called.
func (c *Conn) Run(ctx context.Context, f func(ctx context.Context) error) error {
	// A true result means the flag was already set, i.e. Run was called before.
	if alreadyRan := c.ran.Swap(true); alreadyRan {
		return errors.New("do Run on closed connection")
	}

	runCtx, stop := context.WithCancel(ctx)
	defer stop()

	c.log.Debug("Run: start")
	defer c.log.Debug("Run: end")

	// Starting connection.
	//
	// This will send initial packet to telegram and perform key exchange
	// if needed.
	connectErr := c.connect(runCtx)
	if connectErr != nil {
		return errors.Wrap(connectErr, "start")
	}

	// Every goroutine below is bound to this call through runCtx.
	group := tdsync.NewLogGroup(runCtx, c.log.Named("group"))
	group.Go("handleClose", c.handleClose)
	group.Go("pingLoop", c.pingLoop)
	group.Go("ackLoop", c.ackLoop)
	group.Go("saltsLoop", c.saltLoop)
	group.Go("userCallback", f)
	group.Go("readLoop", c.readLoop)

	if waitErr := group.Wait(); waitErr != nil {
		return errors.Wrap(waitErr, "group")
	}

	return nil
}