package rpc

import (
	"context"
	"sync"
	"sync/atomic"
	"time"

	"github.com/go-faster/errors"
	"go.uber.org/zap"

	"github.com/gotd/td/bin"
	"github.com/gotd/td/clock"
)

// Engine handles RPC requests.
type Engine struct {
	send Send
	drop DropHandler

	mux sync.Mutex
	// Handlers for pending requests, keyed by message ID.
	rpc map[int64]func(*bin.Buffer, error) error
	// Channels closed when the server acknowledges a message, keyed by message ID.
	ack map[int64]chan struct{}

	clock         clock.Clock
	log           *zap.Logger
	retryInterval time.Duration
	maxRetries    int

	// Canceling pending requests in ForceClose.
	reqCtx    context.Context
	reqCancel context.CancelFunc

	wg     sync.WaitGroup
	closed uint32
}

// New creates a new rpc Engine.
func New(send Send, cfg Options) *Engine {
	cfg.setDefaults()

	cfg.Logger.Info("Initialized",
		zap.Duration("retry_interval", cfg.RetryInterval),
		zap.Int("max_retries", cfg.MaxRetries),
	)

	reqCtx, reqCancel := context.WithCancel(context.Background())
	return &Engine{
		rpc: map[int64]func(*bin.Buffer, error) error{},
		ack: map[int64]chan struct{}{},

		send: send,
		drop: cfg.DropHandler,

		log:           cfg.Logger,
		maxRetries:    cfg.MaxRetries,
		retryInterval: cfg.RetryInterval,
		clock:         cfg.Clock,

		reqCtx:    reqCtx,
		reqCancel: reqCancel,
	}
}
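
// A minimal construction sketch, assuming the caller supplies a Send
// implementation; the Options field names follow the cfg fields used above:
//
//	e := rpc.New(send, rpc.Options{
//		RetryInterval: time.Second,
//		MaxRetries:    5,
//	})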

// Request represents a client RPC request.
type Request struct {
	MsgID  int64
	SeqNo  int32
	Input  bin.Encoder // request body
	Output bin.Decoder // destination for the decoded result
}
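
// Issuing a request (sketch; msgID/seqNo are allocated by the caller, e.g.
// by its session layer, and input/output are caller-defined
// bin.Encoder/bin.Decoder values):
//
//	err := e.Do(ctx, rpc.Request{
//		MsgID:  msgID,
//		SeqNo:  seqNo,
//		Input:  input,
//		Output: output,
//	})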

// Do sends the request to the server and blocks until a response is received,
// performing multiple retries if needed.
func (e *Engine) Do(ctx context.Context, req Request) error {
	if e.isClosed() {
		return ErrEngineClosed
	}

	e.wg.Add(1)
	defer e.wg.Done()

	retryCtx, retryClose := context.WithCancel(ctx)
	defer retryClose()

	log := e.log.With(zap.Int64("msg_id", req.MsgID))
	log.Debug("Do called")

	done := make(chan struct{})

	var (
		// Handler result.
		resultErr error
		// Needed to prevent multiple handler calls.
		handlerCalled uint32
	)

	handler := func(rpcBuff *bin.Buffer, rpcErr error) error {
		log.Debug("Handler called")

		if ok := atomic.CompareAndSwapUint32(&handlerCalled, 0, 1); !ok {
			log.Warn("Handler already called")

			return errors.New("handler already called")
		}

		defer retryClose()
		defer close(done)

		if rpcErr != nil {
			resultErr = rpcErr
			return nil
		}

		resultErr = req.Output.Decode(rpcBuff)
		return resultErr
	}

	// Set the callback that will be called when the response message is received.
	e.mux.Lock()
	e.rpc[req.MsgID] = handler
	e.mux.Unlock()

	defer func() {
		// Ensure that the callback can't be called after this function returns.
		e.mux.Lock()
		delete(e.rpc, req.MsgID)
		e.mux.Unlock()
	}()

	// Start retrying.
	sent, err := e.retryUntilAck(retryCtx, req)
	if err != nil && !errors.Is(err, retryCtx.Err()) {
		// If retryCtx was canceled, then one of two things happened:
		// 1. The user canceled the parent context.
		// 2. The RPC result arrived and the callback canceled retryCtx.
		//
		// If this is not a context error, we most likely did not receive an
		// acknowledgment and exceeded the limit of send attempts, or could not
		// write data to the connection, so we return the error.
		return errors.Wrap(err, "retryUntilAck")
	}

	select {
	case <-ctx.Done():
		if !sent {
			return ctx.Err()
		}

		// Set a no-op callback because the server will respond with
		// 'RpcDropAnswer' instead of the expected result.
		//
		// NOTE(ccln): We could decode 'RpcDropAnswer' here, but there is no
		// reason to, because it will also arrive as the response to
		// 'RPCDropAnswerRequest'.
		//
		// https://core.telegram.org/mtproto/service_messages#cancellation-of-an-rpc-query
		e.mux.Lock()
		e.rpc[req.MsgID] = func(*bin.Buffer, error) error { return nil }
		e.mux.Unlock()

		if err := e.drop(req); err != nil {
			log.Info("Failed to drop request", zap.Error(err))
			return ctx.Err()
		}

		log.Debug("Request dropped")
		return ctx.Err()
	case <-e.reqCtx.Done():
		return errors.Wrap(e.reqCtx.Err(), "engine forcibly closed")
	case <-done:
		return resultErr
	}
}

// retryUntilAck resends the request to the server until the request is
// acknowledged.
//
// Returns a nil error if an acknowledgment was received, or an error otherwise.
func (e *Engine) retryUntilAck(ctx context.Context, req Request) (sent bool, err error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var (
		ackChan = e.waitAck(req.MsgID)
		retries = 0
		log     = e.log.Named("retry").With(zap.Int64("msg_id", req.MsgID))
	)

	defer e.removeAck(req.MsgID)

	// Send the request for the first time.
	if err := e.send(ctx, req.MsgID, req.SeqNo, req.Input); err != nil {
		return false, errors.Wrap(err, "send")
	}

	loop := func() error {
		timer := e.clock.Timer(e.retryInterval)
		defer clock.StopTimer(timer)

		for {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-e.reqCtx.Done():
				return errors.Wrap(e.reqCtx.Err(), "engine forcibly closed")
			case <-ackChan:
				log.Debug("Acknowledged")
				return nil
			case <-timer.C():
				timer.Reset(e.retryInterval)

				log.Debug("Acknowledge timed out, performing retry")
				if err := e.send(ctx, req.MsgID, req.SeqNo, req.Input); err != nil {
					if errors.Is(err, context.Canceled) {
						return nil
					}

					log.Error("Retry failed", zap.Error(err))
					return err
				}

				retries++
				if retries >= e.maxRetries {
					log.Error("Retry limit reached", zap.Int64("msg_id", req.MsgID))
					return &RetryLimitReachedErr{
						Retries: retries,
					}
				}
			}
		}
	}

	return true, loop()
}

// NotifyResult notifies the engine about a received RPC response.
func (e *Engine) NotifyResult(msgID int64, b *bin.Buffer) error {
	e.mux.Lock()
	fn, ok := e.rpc[msgID]
	e.mux.Unlock()
	if !ok {
		e.log.Warn("rpc callback not set", zap.Int64("msg_id", msgID))
		return nil
	}

	return fn(b, nil)
}

// NotifyError notifies the engine about a received RPC error.
func (e *Engine) NotifyError(msgID int64, rpcErr error) {
	e.mux.Lock()
	fn, ok := e.rpc[msgID]
	e.mux.Unlock()
	if !ok {
		e.log.Warn("rpc callback not set", zap.Int64("msg_id", msgID))
		return
	}

	// The callback always returns nil when called with a non-nil rpcErr.
	_ = fn(nil, rpcErr)
}
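
// Transport integration sketch (hypothetical receive loop; the actual
// demultiplexing of server messages lives outside this package, e.g. in an
// MTProto connection): after decoding a service message for msgID, the
// caller dispatches it to the engine:
//
//	// On an rpc_result for msgID:
//	_ = e.NotifyResult(msgID, payload)
//	// On an rpc_error for msgID:
//	e.NotifyError(msgID, rpcErr)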

func (e *Engine) isClosed() bool {
	return atomic.LoadUint32(&e.closed) == 1
}

// Close gracefully closes the engine.
// All pending requests are awaited.
// All Do calls on a closed engine return ErrEngineClosed.
func (e *Engine) Close() {
	atomic.StoreUint32(&e.closed, 1)
	e.log.Info("Close called")
	e.wg.Wait()
}

// ForceClose forcibly closes the engine.
// All pending requests are canceled.
// All Do calls on a closed engine return ErrEngineClosed.
func (e *Engine) ForceClose() {
	e.reqCancel()
	e.Close()
}
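
// Shutdown sketch: Close waits for in-flight Do calls to finish, while
// ForceClose cancels them, e.g. on hard connection teardown:
//
//	defer e.Close()
//	// or, when aborting outstanding requests:
//	e.ForceClose()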