You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
219 lines
5.6 KiB
219 lines
5.6 KiB
3 years ago
|
package mtproto
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"io"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/go-faster/errors"
|
||
|
"go.uber.org/atomic"
|
||
|
"go.uber.org/zap"
|
||
|
|
||
|
"github.com/gotd/td/bin"
|
||
|
"github.com/gotd/td/clock"
|
||
|
"github.com/gotd/td/internal/crypto"
|
||
|
"github.com/gotd/td/internal/exchange"
|
||
|
"github.com/gotd/td/internal/mtproto/salts"
|
||
|
"github.com/gotd/td/internal/proto"
|
||
|
"github.com/gotd/td/internal/rpc"
|
||
|
"github.com/gotd/td/internal/tdsync"
|
||
|
"github.com/gotd/td/internal/tmap"
|
||
|
"github.com/gotd/td/transport"
|
||
|
)
|
||
|
|
||
|
// Handler will be called on received message from Telegram.
|
||
|
type Handler interface {
|
||
|
OnMessage(b *bin.Buffer) error
|
||
|
OnSession(session Session) error
|
||
|
}
|
||
|
|
||
|
// MessageIDSource is message id generator.
|
||
|
type MessageIDSource interface {
|
||
|
New(t proto.MessageType) int64
|
||
|
}
|
||
|
|
||
|
// MessageBuf is a buffer of message identifiers.
//
// Conn keeps one in its messageIDBuf field for replay attack protection.
// NOTE(review): Consume presumably reports false for an id that was already
// seen — confirm against the implementation in proto.
type MessageBuf interface {
	Consume(msgID int64) bool
}
|
||
|
|
||
|
// Cipher handles message encryption and decryption.
|
||
|
type Cipher interface {
|
||
|
DecryptFromBuffer(k crypto.AuthKey, buf *bin.Buffer) (*crypto.EncryptedMessageData, error)
|
||
|
Encrypt(key crypto.AuthKey, data crypto.EncryptedMessageData, b *bin.Buffer) error
|
||
|
}
|
||
|
|
||
|
// Dialer is an abstraction for MTProto transport connection creator.
|
||
|
type Dialer func(ctx context.Context) (transport.Conn, error)
|
||
|
|
||
|
// Conn represents a MTProto client to Telegram.
|
||
|
type Conn struct {
|
||
|
dcID int
|
||
|
|
||
|
dialer Dialer
|
||
|
conn transport.Conn
|
||
|
handler Handler
|
||
|
rpc *rpc.Engine
|
||
|
rsaPublicKeys []exchange.PublicKey
|
||
|
types *tmap.Map
|
||
|
|
||
|
// Wrappers for external world, like current time, logs or PRNG.
|
||
|
// Should be immutable.
|
||
|
clock clock.Clock
|
||
|
rand io.Reader
|
||
|
cipher Cipher
|
||
|
log *zap.Logger
|
||
|
messageID MessageIDSource
|
||
|
messageIDBuf MessageBuf // replay attack protection
|
||
|
|
||
|
// use session() to access authKey, salt or sessionID.
|
||
|
sessionMux sync.RWMutex
|
||
|
authKey crypto.AuthKey
|
||
|
salt int64
|
||
|
sessionID int64
|
||
|
|
||
|
// server salts fetched by getSalts.
|
||
|
salts salts.Salts
|
||
|
|
||
|
// sentContentMessages is count of created content messages, used to
|
||
|
// compute sequence number within session.
|
||
|
sentContentMessages int32
|
||
|
reqMux sync.Mutex
|
||
|
|
||
|
// ackSendChan is queue for outgoing message id's that require waiting for
|
||
|
// ack from server.
|
||
|
ackSendChan chan int64
|
||
|
ackBatchSize int
|
||
|
ackInterval time.Duration
|
||
|
|
||
|
// callbacks for ping results.
|
||
|
// Key is ping id.
|
||
|
ping map[int64]chan struct{}
|
||
|
pingMux sync.Mutex
|
||
|
// pingTimeout sets ping_delay_disconnect delay.
|
||
|
pingTimeout time.Duration
|
||
|
// pingInterval is duration between ping_delay_disconnect request.
|
||
|
pingInterval time.Duration
|
||
|
|
||
|
// gotSession is a signal channel for wait for handleSessionCreated message.
|
||
|
gotSession *tdsync.Ready
|
||
|
|
||
|
// exchangeLock locks write calls during key exchange.
|
||
|
exchangeLock sync.RWMutex
|
||
|
|
||
|
// compressThreshold is a threshold in bytes to determine that message
|
||
|
// is large enough to be compressed using gzip.
|
||
|
compressThreshold int
|
||
|
dialTimeout time.Duration
|
||
|
exchangeTimeout time.Duration
|
||
|
saltFetchInterval time.Duration
|
||
|
getTimeout func(req uint32) time.Duration
|
||
|
// Ensure Run once.
|
||
|
ran atomic.Bool
|
||
|
}
|
||
|
|
||
|
// New creates new unstarted connection.
|
||
|
func New(dialer Dialer, opt Options) *Conn {
|
||
|
// Set default values, if user does not set.
|
||
|
opt.setDefaults()
|
||
|
|
||
|
conn := &Conn{
|
||
|
dcID: opt.DC,
|
||
|
|
||
|
dialer: dialer,
|
||
|
clock: opt.Clock,
|
||
|
rand: opt.Random,
|
||
|
cipher: opt.Cipher,
|
||
|
log: opt.Logger,
|
||
|
messageID: opt.MessageID,
|
||
|
messageIDBuf: proto.NewMessageIDBuf(100),
|
||
|
|
||
|
ackSendChan: make(chan int64),
|
||
|
ackInterval: opt.AckInterval,
|
||
|
ackBatchSize: opt.AckBatchSize,
|
||
|
|
||
|
rsaPublicKeys: opt.PublicKeys,
|
||
|
handler: opt.Handler,
|
||
|
types: opt.Types,
|
||
|
|
||
|
authKey: opt.Key,
|
||
|
salt: opt.Salt,
|
||
|
|
||
|
ping: map[int64]chan struct{}{},
|
||
|
pingTimeout: opt.PingTimeout,
|
||
|
pingInterval: opt.PingInterval,
|
||
|
|
||
|
gotSession: tdsync.NewReady(),
|
||
|
|
||
|
rpc: opt.engine,
|
||
|
compressThreshold: opt.CompressThreshold,
|
||
|
dialTimeout: opt.DialTimeout,
|
||
|
exchangeTimeout: opt.ExchangeTimeout,
|
||
|
saltFetchInterval: opt.SaltFetchInterval,
|
||
|
getTimeout: opt.RequestTimeout,
|
||
|
}
|
||
|
if conn.rpc == nil {
|
||
|
conn.rpc = rpc.New(conn.writeContentMessage, rpc.Options{
|
||
|
Logger: opt.Logger.Named("rpc"),
|
||
|
RetryInterval: opt.RetryInterval,
|
||
|
MaxRetries: opt.MaxRetries,
|
||
|
Clock: opt.Clock,
|
||
|
DropHandler: conn.dropRPC,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return conn
|
||
|
}
|
||
|
|
||
|
// handleClose closes rpc engine and underlying connection on context done.
|
||
|
func (c *Conn) handleClose(ctx context.Context) error {
|
||
|
<-ctx.Done()
|
||
|
c.log.Debug("Closing")
|
||
|
|
||
|
// Close RPC Engine.
|
||
|
c.rpc.ForceClose()
|
||
|
// Close connection.
|
||
|
if err := c.conn.Close(); err != nil {
|
||
|
c.log.Debug("Failed to cleanup connection", zap.Error(err))
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Run initializes MTProto connection to server and blocks until disconnection.
|
||
|
//
|
||
|
// When connection is ready, Handler.OnSession is called.
|
||
|
func (c *Conn) Run(ctx context.Context, f func(ctx context.Context) error) error {
|
||
|
// Starting connection.
|
||
|
//
|
||
|
// This will send initial packet to telegram and perform key exchange
|
||
|
// if needed.
|
||
|
if c.ran.Swap(true) {
|
||
|
return errors.New("do Run on closed connection")
|
||
|
}
|
||
|
|
||
|
ctx, cancel := context.WithCancel(ctx)
|
||
|
defer cancel()
|
||
|
|
||
|
c.log.Debug("Run: start")
|
||
|
defer c.log.Debug("Run: end")
|
||
|
if err := c.connect(ctx); err != nil {
|
||
|
return errors.Wrap(err, "start")
|
||
|
}
|
||
|
{
|
||
|
// All goroutines are bound to current call.
|
||
|
g := tdsync.NewLogGroup(ctx, c.log.Named("group"))
|
||
|
g.Go("handleClose", c.handleClose)
|
||
|
g.Go("pingLoop", c.pingLoop)
|
||
|
g.Go("ackLoop", c.ackLoop)
|
||
|
g.Go("saltsLoop", c.saltLoop)
|
||
|
g.Go("userCallback", f)
|
||
|
g.Go("readLoop", c.readLoop)
|
||
|
|
||
|
if err := g.Wait(); err != nil {
|
||
|
return errors.Wrap(err, "group")
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|