You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

213 lines
5.8 KiB

package mtproto
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
"github.com/go-faster/errors"
"go.uber.org/zap"
"github.com/gotd/td/bin"
"github.com/gotd/td/internal/crypto"
"github.com/gotd/td/internal/proto"
"github.com/gotd/td/internal/proto/codec"
)
// Acceptance window for the creation time encoded in a message identifier.
//
// See https://core.telegram.org/mtproto/description#message-identifier-msg-id:
// a message is rejected when it was created more than 300 seconds ago or more
// than 30 seconds in the future. This is needed to protect from replay attacks.
const (
	// maxPast is the largest tolerated age of a message.
	maxPast = 300 * time.Second
	// maxFuture is the largest tolerated lead of a message over the local clock.
	maxFuture = 30 * time.Second
)
// errRejected is the sentinel error for an incoming message that failed
// validation and should not be processed. Validation failures in this file
// wrap it, so it can be detected with errors.Is.
var errRejected = errors.New("message rejected")
// checkMessageID validates the identifier of an incoming message against now.
//
// The identifier is accepted only when its type is proto.MessageFromServer or
// proto.MessageServerResponse and its embedded creation time is at most
// maxPast behind and at most maxFuture ahead of now. Every failure wraps
// errRejected; nil is returned for an acceptable identifier.
func checkMessageID(now time.Time, rawID int64) error {
	id := proto.MessageID(rawID)

	// Anything that is not marked as coming from the server is refused.
	if kind := id.Type(); kind != proto.MessageFromServer && kind != proto.MessageServerResponse {
		return errors.Wrapf(errRejected, "unexpected type %s", kind)
	}

	created := id.Time()
	switch {
	case created.Before(now) && now.Sub(created) > maxPast:
		return errors.Wrap(errRejected, "created too far in past")
	case created.Sub(now) > maxFuture:
		return errors.Wrap(errRejected, "created too far in future")
	}
	return nil
}
// decryptMessage decrypts b with the key of the current session and validates
// the decrypted message.
//
// A decryption failure is returned wrapped with "decrypt". The message is then
// rejected (error wrapping errRejected) when its session id differs from the
// current one, when checkMessageID refuses its message id, or when
// c.messageIDBuf does not accept the id. These checks protect from replay
// attacks.
func (c *Conn) decryptMessage(b *bin.Buffer) (*crypto.EncryptedMessageData, error) {
	sess := c.session()

	decrypted, err := c.cipher.DecryptFromBuffer(sess.Key, b)
	if err != nil {
		return nil, errors.Wrap(err, "decrypt")
	}

	// The message must belong to the session we are currently using.
	if decrypted.SessionID != sess.ID {
		return nil, errors.Wrapf(errRejected, "invalid session (got %d, expected %d)", decrypted.SessionID, sess.ID)
	}

	id := decrypted.MessageID

	// Type and creation-time window of the identifier.
	idErr := checkMessageID(c.clock.Now(), id)
	if idErr != nil {
		return nil, errors.Wrapf(idErr, "bad message id %d", id)
	}

	// Consume is called last so that only otherwise valid ids reach the buffer.
	if accepted := c.messageIDBuf.Consume(id); !accepted {
		return nil, errors.Wrapf(errRejected, "duplicate or too low message id %d", id)
	}

	return decrypted, nil
}
// consumeMessage decrypts a single received buffer, dispatches the message to
// handleMessage and, when required, queues an acknowledgement for it.
//
// Rejected messages (errors wrapping errRejected) are logged and ignored; any
// other decryption error is returned wrapped with "consume message". Handler
// errors are only logged. ctx.Err() is returned if ctx is done before the
// acknowledgement could be queued.
func (c *Conn) consumeMessage(ctx context.Context, buf *bin.Buffer) error {
	msg, err := c.decryptMessage(buf)
	switch {
	case errors.Is(err, errRejected):
		c.log.Warn("Ignoring rejected message", zap.Error(err))
		return nil
	case err != nil:
		return errors.Wrap(err, "consume message")
	}

	payload := &bin.Buffer{Buf: msg.Data()}
	if handleErr := c.handleMessage(msg.MessageID, payload); handleErr != nil {
		// Returning here would shut down the whole connection, which can be
		// unexpected, so the error is only logged.
		//
		// The acknowledgement below is sent even on error: the client should
		// restore from missing updates via explicit pts check and getDiff call.
		c.log.Warn("Error while handling message", zap.Error(handleErr))
	}

	// The lowest bit of SeqNo tells whether the message needs an acknowledgement.
	if msg.SeqNo&0x01 == 0 {
		return nil
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case c.ackSendChan <- msg.MessageID:
		return nil
	}
}
// noUpdates reports whether err is a read timeout, i.e. a *net.OpError whose
// Timeout method returns true. Such errors are expected because a read
// deadline is set on the connection; a debug line is logged for them.
func (c *Conn) noUpdates(err error) bool {
	var opErr *net.OpError
	if !errors.As(err, &opErr) || !opErr.Timeout() {
		return false
	}

	c.log.Debug("No updates")
	return true
}
// handleAuthKeyNotFound reacts to the server not finding the auth key we
// provided by creating a new one via createAuthKey.
//
// It returns the createAuthKey error wrapped with "unable to create auth key",
// or nil once the keys were re-created.
func (c *Conn) handleAuthKeyNotFound(ctx context.Context) error {
	if sessionID := c.session().ID; sessionID == 0 {
		// A zero session id can cause the 404 error as well, see
		// https://github.com/gotd/td/issues/107
		//
		// createAuthKey should recover from this, but in general this branch
		// is expected to be unreachable.
		c.log.Warn("BUG: zero session id found")
	}

	c.log.Warn("Re-generating keys (server not found key that we provided)")
	err := c.createAuthKey(ctx)
	if err != nil {
		return errors.Wrap(err, "unable to create auth key")
	}
	c.log.Info("Re-created auth keys")

	// The request will be retried by the ack loop; this could probably be
	// made faster.
	return nil
}
// readLoop receives messages from the connection until ctx is done or an
// unrecoverable error occurs, handing every received buffer to consumeMessage
// in its own goroutine.
//
// Read timeouts (see noUpdates) are skipped, an "auth key not found" protocol
// error triggers handleAuthKeyNotFound, and an error stored by a
// consumeMessage goroutine halts the loop on the next iteration. All spawned
// goroutines are awaited before the function returns.
func (c *Conn) readLoop(ctx context.Context) (err error) {
	log := c.log.Named("read")
	log.Debug("Read loop started")
	defer func() {
		// Attach the reason only when the loop stopped with an error.
		if err == nil {
			log.Debug("Read loop done")
			return
		}
		log.With(zap.NamedError("reason", err)).Debug("Read loop done")
	}()

	var (
		// Tracks spawned goroutines so that we can wait for all of them.
		handlers sync.WaitGroup
		// Holds the last error reported by consumeMessage.
		lastErr atomic.Value
	)
	defer handlers.Wait()

	for {
		// A stored consumeMessage error halts the loop; it should be
		// something critical with crypto.
		if stored, ok := lastErr.Load().(error); ok && stored != nil {
			return errors.Wrap(stored, "halting")
		}

		// A fresh buffer is allocated per message. Several ways of reusing
		// buffers were tried, but a naive implementation wastes a lot of idle
		// memory. A proper optimization would probably need a rework of
		// bin.Buffer with a pool sharded (by payload size?) that is consulted
		// after the message size is read (after readLen). The additional
		// complexity is probably not worth it.
		buf := &bin.Buffer{}

		recvErr := c.conn.Recv(ctx, buf)
		if recvErr == nil {
			// One goroutine per incoming message utilizes as much resources
			// as possible while keeping idle utilization low. This replaced
			// the "worker" model due to its idle utilization overhead,
			// especially on multi-CPU systems with multiple running clients.
			handlers.Add(1)
			go func() {
				defer handlers.Done()

				consumeErr := c.consumeMessage(ctx, buf)
				if consumeErr == nil {
					return
				}
				log.Error("Failed to process message", zap.Error(consumeErr))
				lastErr.Store(errors.Wrap(consumeErr, "consume"))
			}()
			continue
		}

		// Cancellation takes precedence over any interpretation of the error.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		if c.noUpdates(recvErr) {
			continue
		}

		var protoErr *codec.ProtocolErr
		if errors.As(recvErr, &protoErr) && protoErr.Code == codec.CodeAuthKeyNotFound {
			if keyErr := c.handleAuthKeyNotFound(ctx); keyErr != nil {
				return errors.Wrap(keyErr, "auth key not found")
			}
			continue
		}

		// ctx is checked once more, as it may have become done in the meantime.
		select {
		case <-ctx.Done():
			return errors.Wrap(ctx.Err(), "read loop")
		default:
		}
		return errors.Wrap(recvErr, "read")
	}
}