You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
214 lines
5.8 KiB
214 lines
5.8 KiB
3 years ago
|
package mtproto
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
"github.com/go-faster/errors"
|
||
|
"go.uber.org/zap"
|
||
|
|
||
|
"github.com/gotd/td/bin"
|
||
|
"github.com/gotd/td/internal/crypto"
|
||
|
"github.com/gotd/td/internal/proto"
|
||
|
"github.com/gotd/td/internal/proto/codec"
|
||
|
)
|
||
|
|
||
|
// https://core.telegram.org/mtproto/description#message-identifier-msg-id
|
||
|
// A message is rejected over 300 seconds after it is created or 30 seconds
|
||
|
// before it is created (this is needed to protect from replay attacks).
|
||
|
const (
|
||
|
maxPast = time.Second * 300
|
||
|
maxFuture = time.Second * 30
|
||
|
)
|
||
|
|
||
|
// errRejected is returned on invalid message that should not be processed.
|
||
|
var errRejected = errors.New("message rejected")
|
||
|
|
||
|
func checkMessageID(now time.Time, rawID int64) error {
|
||
|
id := proto.MessageID(rawID)
|
||
|
|
||
|
// Check that message is from server.
|
||
|
switch id.Type() {
|
||
|
case proto.MessageFromServer, proto.MessageServerResponse:
|
||
|
// Valid.
|
||
|
default:
|
||
|
return errors.Wrapf(errRejected, "unexpected type %s", id.Type())
|
||
|
}
|
||
|
|
||
|
created := id.Time()
|
||
|
if created.Before(now) && now.Sub(created) > maxPast {
|
||
|
return errors.Wrap(errRejected, "created too far in past")
|
||
|
}
|
||
|
if created.Sub(now) > maxFuture {
|
||
|
return errors.Wrap(errRejected, "created too far in future")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *Conn) decryptMessage(b *bin.Buffer) (*crypto.EncryptedMessageData, error) {
|
||
|
session := c.session()
|
||
|
msg, err := c.cipher.DecryptFromBuffer(session.Key, b)
|
||
|
if err != nil {
|
||
|
return nil, errors.Wrap(err, "decrypt")
|
||
|
}
|
||
|
|
||
|
// Validating message. This protects from replay attacks.
|
||
|
if msg.SessionID != session.ID {
|
||
|
return nil, errors.Wrapf(errRejected, "invalid session (got %d, expected %d)", msg.SessionID, session.ID)
|
||
|
}
|
||
|
if err := checkMessageID(c.clock.Now(), msg.MessageID); err != nil {
|
||
|
return nil, errors.Wrapf(err, "bad message id %d", msg.MessageID)
|
||
|
}
|
||
|
if !c.messageIDBuf.Consume(msg.MessageID) {
|
||
|
return nil, errors.Wrapf(errRejected, "duplicate or too low message id %d", msg.MessageID)
|
||
|
}
|
||
|
|
||
|
return msg, nil
|
||
|
}
|
||
|
|
||
|
func (c *Conn) consumeMessage(ctx context.Context, buf *bin.Buffer) error {
|
||
|
msg, err := c.decryptMessage(buf)
|
||
|
if errors.Is(err, errRejected) {
|
||
|
c.log.Warn("Ignoring rejected message", zap.Error(err))
|
||
|
return nil
|
||
|
}
|
||
|
if err != nil {
|
||
|
return errors.Wrap(err, "consume message")
|
||
|
}
|
||
|
|
||
|
if err := c.handleMessage(msg.MessageID, &bin.Buffer{Buf: msg.Data()}); err != nil {
|
||
|
// Probably we can return here, but this will shutdown whole
|
||
|
// connection which can be unexpected.
|
||
|
c.log.Warn("Error while handling message", zap.Error(err))
|
||
|
// Sending acknowledge even on error. Client should restore
|
||
|
// from missing updates via explicit pts check and getDiff call.
|
||
|
}
|
||
|
|
||
|
needAck := (msg.SeqNo & 0x01) != 0
|
||
|
if needAck {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return ctx.Err()
|
||
|
case c.ackSendChan <- msg.MessageID:
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *Conn) noUpdates(err error) bool {
|
||
|
// Checking for read timeout.
|
||
|
var syscall *net.OpError
|
||
|
if errors.As(err, &syscall) && syscall.Timeout() {
|
||
|
// We call SetReadDeadline so such error is expected.
|
||
|
c.log.Debug("No updates")
|
||
|
return true
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func (c *Conn) handleAuthKeyNotFound(ctx context.Context) error {
|
||
|
if c.session().ID == 0 {
|
||
|
// The 404 error can also be caused by zero session id.
|
||
|
// See https://github.com/gotd/td/issues/107
|
||
|
//
|
||
|
// We should recover from this in createAuthKey, but in general
|
||
|
// this code branch should be unreachable.
|
||
|
c.log.Warn("BUG: zero session id found")
|
||
|
}
|
||
|
c.log.Warn("Re-generating keys (server not found key that we provided)")
|
||
|
if err := c.createAuthKey(ctx); err != nil {
|
||
|
return errors.Wrap(err, "unable to create auth key")
|
||
|
}
|
||
|
c.log.Info("Re-created auth keys")
|
||
|
// Request will be retried by ack loop.
|
||
|
// Probably we can speed-up this.
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *Conn) readLoop(ctx context.Context) (err error) {
|
||
|
log := c.log.Named("read")
|
||
|
log.Debug("Read loop started")
|
||
|
defer func() {
|
||
|
l := log
|
||
|
if err != nil {
|
||
|
l = log.With(zap.NamedError("reason", err))
|
||
|
}
|
||
|
l.Debug("Read loop done")
|
||
|
}()
|
||
|
|
||
|
var (
|
||
|
// Last error encountered by consumeMessage.
|
||
|
lastErr atomic.Value
|
||
|
// To wait all spawned goroutines
|
||
|
handlers sync.WaitGroup
|
||
|
)
|
||
|
defer handlers.Wait()
|
||
|
|
||
|
for {
|
||
|
// We've tried multiple ways to reduce allocations via reusing buffer,
|
||
|
// but naive implementation induces high idle memory waste.
|
||
|
//
|
||
|
// Proper optimization will probably require total rework of bin.Buffer
|
||
|
// with sharded (by payload size?) pool that can be used after message
|
||
|
// size read (after readLen).
|
||
|
//
|
||
|
// Such optimization can introduce additional complexity overhead and
|
||
|
// is probably not worth it.
|
||
|
buf := &bin.Buffer{}
|
||
|
|
||
|
// Halting if consumeMessage encountered error.
|
||
|
// Should be something critical with crypto.
|
||
|
if err, ok := lastErr.Load().(error); ok && err != nil {
|
||
|
return errors.Wrap(err, "halting")
|
||
|
}
|
||
|
|
||
|
if err := c.conn.Recv(ctx, buf); err != nil {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return ctx.Err()
|
||
|
default:
|
||
|
if c.noUpdates(err) {
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var protoErr *codec.ProtocolErr
|
||
|
if errors.As(err, &protoErr) && protoErr.Code == codec.CodeAuthKeyNotFound {
|
||
|
if err := c.handleAuthKeyNotFound(ctx); err != nil {
|
||
|
return errors.Wrap(err, "auth key not found")
|
||
|
}
|
||
|
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return errors.Wrap(ctx.Err(), "read loop")
|
||
|
default:
|
||
|
return errors.Wrap(err, "read")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
handlers.Add(1)
|
||
|
go func() {
|
||
|
defer handlers.Done()
|
||
|
|
||
|
// Spawning goroutine per incoming message to utilize as much
|
||
|
// resources as possible while keeping idle utilization low.
|
||
|
//
|
||
|
// The "worker" model was replaced by this due to idle utilization
|
||
|
// overhead, especially on multi-CPU systems with multiple running
|
||
|
// clients.
|
||
|
if err := c.consumeMessage(ctx, buf); err != nil {
|
||
|
log.Error("Failed to process message", zap.Error(err))
|
||
|
lastErr.Store(errors.Wrap(err, "consume"))
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
}
|