You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
218 lines
5.6 KiB
218 lines
5.6 KiB
package mtproto |
|
|
|
import ( |
|
"context" |
|
"io" |
|
"sync" |
|
"time" |
|
|
|
"github.com/go-faster/errors" |
|
"go.uber.org/atomic" |
|
"go.uber.org/zap" |
|
|
|
"github.com/gotd/td/bin" |
|
"github.com/gotd/td/clock" |
|
"github.com/gotd/td/internal/crypto" |
|
"github.com/gotd/td/internal/exchange" |
|
"github.com/gotd/td/internal/mtproto/salts" |
|
"github.com/gotd/td/internal/proto" |
|
"github.com/gotd/td/internal/rpc" |
|
"github.com/gotd/td/internal/tdsync" |
|
"github.com/gotd/td/internal/tmap" |
|
"github.com/gotd/td/transport" |
|
) |
|
|
|
// Handler will be called on received message from Telegram. |
|
type Handler interface { |
|
OnMessage(b *bin.Buffer) error |
|
OnSession(session Session) error |
|
} |
|
|
|
// MessageIDSource is message id generator. |
|
type MessageIDSource interface { |
|
New(t proto.MessageType) int64 |
|
} |
|
|
|
// MessageBuf is a buffer of message identifiers.
type MessageBuf interface {
	// Consume takes a message id and reports a boolean verdict for it.
	// NOTE(review): presumably true means the id was accepted (not seen
	// before) — confirm against the implementation.
	Consume(id int64) bool
}
|
|
|
// Cipher handles message encryption and decryption. |
|
type Cipher interface { |
|
DecryptFromBuffer(k crypto.AuthKey, buf *bin.Buffer) (*crypto.EncryptedMessageData, error) |
|
Encrypt(key crypto.AuthKey, data crypto.EncryptedMessageData, b *bin.Buffer) error |
|
} |
|
|
|
// Dialer is an abstraction for MTProto transport connection creator. |
|
type Dialer func(ctx context.Context) (transport.Conn, error) |
|
|
|
// Conn represents a MTProto client to Telegram. |
|
type Conn struct { |
|
dcID int |
|
|
|
dialer Dialer |
|
conn transport.Conn |
|
handler Handler |
|
rpc *rpc.Engine |
|
rsaPublicKeys []exchange.PublicKey |
|
types *tmap.Map |
|
|
|
// Wrappers for external world, like current time, logs or PRNG. |
|
// Should be immutable. |
|
clock clock.Clock |
|
rand io.Reader |
|
cipher Cipher |
|
log *zap.Logger |
|
messageID MessageIDSource |
|
messageIDBuf MessageBuf // replay attack protection |
|
|
|
// use session() to access authKey, salt or sessionID. |
|
sessionMux sync.RWMutex |
|
authKey crypto.AuthKey |
|
salt int64 |
|
sessionID int64 |
|
|
|
// server salts fetched by getSalts. |
|
salts salts.Salts |
|
|
|
// sentContentMessages is count of created content messages, used to |
|
// compute sequence number within session. |
|
sentContentMessages int32 |
|
reqMux sync.Mutex |
|
|
|
// ackSendChan is queue for outgoing message id's that require waiting for |
|
// ack from server. |
|
ackSendChan chan int64 |
|
ackBatchSize int |
|
ackInterval time.Duration |
|
|
|
// callbacks for ping results. |
|
// Key is ping id. |
|
ping map[int64]chan struct{} |
|
pingMux sync.Mutex |
|
// pingTimeout sets ping_delay_disconnect delay. |
|
pingTimeout time.Duration |
|
// pingInterval is duration between ping_delay_disconnect request. |
|
pingInterval time.Duration |
|
|
|
// gotSession is a signal channel for wait for handleSessionCreated message. |
|
gotSession *tdsync.Ready |
|
|
|
// exchangeLock locks write calls during key exchange. |
|
exchangeLock sync.RWMutex |
|
|
|
// compressThreshold is a threshold in bytes to determine that message |
|
// is large enough to be compressed using gzip. |
|
compressThreshold int |
|
dialTimeout time.Duration |
|
exchangeTimeout time.Duration |
|
saltFetchInterval time.Duration |
|
getTimeout func(req uint32) time.Duration |
|
// Ensure Run once. |
|
ran atomic.Bool |
|
} |
|
|
|
// New creates new unstarted connection. |
|
func New(dialer Dialer, opt Options) *Conn { |
|
// Set default values, if user does not set. |
|
opt.setDefaults() |
|
|
|
conn := &Conn{ |
|
dcID: opt.DC, |
|
|
|
dialer: dialer, |
|
clock: opt.Clock, |
|
rand: opt.Random, |
|
cipher: opt.Cipher, |
|
log: opt.Logger, |
|
messageID: opt.MessageID, |
|
messageIDBuf: proto.NewMessageIDBuf(100), |
|
|
|
ackSendChan: make(chan int64), |
|
ackInterval: opt.AckInterval, |
|
ackBatchSize: opt.AckBatchSize, |
|
|
|
rsaPublicKeys: opt.PublicKeys, |
|
handler: opt.Handler, |
|
types: opt.Types, |
|
|
|
authKey: opt.Key, |
|
salt: opt.Salt, |
|
|
|
ping: map[int64]chan struct{}{}, |
|
pingTimeout: opt.PingTimeout, |
|
pingInterval: opt.PingInterval, |
|
|
|
gotSession: tdsync.NewReady(), |
|
|
|
rpc: opt.engine, |
|
compressThreshold: opt.CompressThreshold, |
|
dialTimeout: opt.DialTimeout, |
|
exchangeTimeout: opt.ExchangeTimeout, |
|
saltFetchInterval: opt.SaltFetchInterval, |
|
getTimeout: opt.RequestTimeout, |
|
} |
|
if conn.rpc == nil { |
|
conn.rpc = rpc.New(conn.writeContentMessage, rpc.Options{ |
|
Logger: opt.Logger.Named("rpc"), |
|
RetryInterval: opt.RetryInterval, |
|
MaxRetries: opt.MaxRetries, |
|
Clock: opt.Clock, |
|
DropHandler: conn.dropRPC, |
|
}) |
|
} |
|
|
|
return conn |
|
} |
|
|
|
// handleClose closes rpc engine and underlying connection on context done. |
|
func (c *Conn) handleClose(ctx context.Context) error { |
|
<-ctx.Done() |
|
c.log.Debug("Closing") |
|
|
|
// Close RPC Engine. |
|
c.rpc.ForceClose() |
|
// Close connection. |
|
if err := c.conn.Close(); err != nil { |
|
c.log.Debug("Failed to cleanup connection", zap.Error(err)) |
|
} |
|
return nil |
|
} |
|
|
|
// Run initializes MTProto connection to server and blocks until disconnection. |
|
// |
|
// When connection is ready, Handler.OnSession is called. |
|
func (c *Conn) Run(ctx context.Context, f func(ctx context.Context) error) error { |
|
// Starting connection. |
|
// |
|
// This will send initial packet to telegram and perform key exchange |
|
// if needed. |
|
if c.ran.Swap(true) { |
|
return errors.New("do Run on closed connection") |
|
} |
|
|
|
ctx, cancel := context.WithCancel(ctx) |
|
defer cancel() |
|
|
|
c.log.Debug("Run: start") |
|
defer c.log.Debug("Run: end") |
|
if err := c.connect(ctx); err != nil { |
|
return errors.Wrap(err, "start") |
|
} |
|
{ |
|
// All goroutines are bound to current call. |
|
g := tdsync.NewLogGroup(ctx, c.log.Named("group")) |
|
g.Go("handleClose", c.handleClose) |
|
g.Go("pingLoop", c.pingLoop) |
|
g.Go("ackLoop", c.ackLoop) |
|
g.Go("saltsLoop", c.saltLoop) |
|
g.Go("userCallback", f) |
|
g.Go("readLoop", c.readLoop) |
|
|
|
if err := g.Wait(); err != nil { |
|
return errors.Wrap(err, "group") |
|
} |
|
} |
|
return nil |
|
}
|
|
|