You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
668 lines
15 KiB
668 lines
15 KiB
package redis |
|
|
|
import ( |
|
"context" |
|
"fmt" |
|
"strings" |
|
"sync" |
|
"time" |
|
|
|
"github.com/go-redis/redis/v8/internal" |
|
"github.com/go-redis/redis/v8/internal/pool" |
|
"github.com/go-redis/redis/v8/internal/proto" |
|
) |
|
|
|
// PubSub implements Pub/Sub commands as described in |
|
// http://redis.io/topics/pubsub. Message receiving is NOT safe |
|
// for concurrent use by multiple goroutines. |
|
// |
|
// PubSub automatically reconnects to Redis Server and resubscribes |
|
// to the channels in case of network errors. |
|
type PubSub struct { |
|
opt *Options |
|
|
|
newConn func(ctx context.Context, channels []string) (*pool.Conn, error) |
|
closeConn func(*pool.Conn) error |
|
|
|
mu sync.Mutex |
|
cn *pool.Conn |
|
channels map[string]struct{} |
|
patterns map[string]struct{} |
|
|
|
closed bool |
|
exit chan struct{} |
|
|
|
cmd *Cmd |
|
|
|
chOnce sync.Once |
|
msgCh *channel |
|
allCh *channel |
|
} |
|
|
|
func (c *PubSub) init() { |
|
c.exit = make(chan struct{}) |
|
} |
|
|
|
// String returns a description listing the subscribed channels followed by
// the subscribed patterns, e.g. "PubSub(news, sport*)".
//
// The subscription maps are mutated under c.mu by Subscribe, Unsubscribe and
// friends, so they are read under the same lock here; iterating them
// unlocked is a data race (and can be a fatal concurrent map access).
func (c *PubSub) String() string {
	c.mu.Lock()
	defer c.mu.Unlock()

	channels := mapKeys(c.channels)
	channels = append(channels, mapKeys(c.patterns)...)
	return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
}
|
|
|
// connWithLock wraps conn (with no new channels) in c.mu, for callers that do
// not already hold the lock.
func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.conn(ctx, nil)
}
|
|
|
// conn returns the active connection, dialing a new one when there is none.
// The dial receives the tracked channels plus newChannels. A freshly dialed
// connection is re-subscribed to every tracked channel and pattern before it
// becomes the active one. It returns pool.ErrClosed once the PubSub is closed.
// Callers in this file invoke it with c.mu held.
func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
	switch {
	case c.closed:
		return nil, pool.ErrClosed
	case c.cn != nil:
		return c.cn, nil
	}

	fresh, dialErr := c.newConn(ctx, append(mapKeys(c.channels), newChannels...))
	if dialErr != nil {
		return nil, dialErr
	}

	if subErr := c.resubscribe(ctx, fresh); subErr != nil {
		// Best effort: the resubscribe failure is the error callers need.
		_ = c.closeConn(fresh)
		return nil, subErr
	}

	c.cn = fresh
	return fresh, nil
}
|
|
|
// writeCmd serializes cmd onto cn using the configured WriteTimeout.
// It only writes; no reply is read here.
func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
	encode := func(wr *proto.Writer) error {
		return writeCmd(wr, cmd)
	}
	return cn.WithWriter(ctx, c.opt.WriteTimeout, encode)
}
|
|
|
// resubscribe replays SUBSCRIBE for every tracked channel and PSUBSCRIBE for
// every tracked pattern on cn. Both are attempted even if the first fails.
// The first error encountered (channels before patterns) is returned.
func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
	var errs [2]error
	if len(c.channels) > 0 {
		errs[0] = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels))
	}
	if len(c.patterns) > 0 {
		errs[1] = c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns))
	}
	for _, e := range errs {
		if e != nil {
			return e
		}
	}
	return nil
}
|
|
|
// mapKeys returns the keys of m in map-iteration (unspecified) order.
// The result is never nil: an empty or nil map yields an empty slice whose
// capacity equals its length.
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	return keys
}
|
|
|
// _subscribe writes one command, redisCmd followed by every name in channels,
// to cn. It does not read a reply; confirmations are parsed later by
// newMessage on the receive path.
func (c *PubSub) _subscribe(
	ctx context.Context, cn *pool.Conn, redisCmd string, channels []string,
) error {
	argv := make([]interface{}, 1+len(channels))
	argv[0] = redisCmd
	for i, name := range channels {
		argv[i+1] = name
	}
	return c.writeCmd(ctx, cn, NewSliceCmd(ctx, argv...))
}
|
|
|
// releaseConnWithLock wraps releaseConn in c.mu, for callers (such as the
// receive path) that do not already hold the lock.
func (c *PubSub) releaseConnWithLock(
	ctx context.Context,
	cn *pool.Conn,
	err error,
	allowTimeout bool,
) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.releaseConn(ctx, cn, err, allowTimeout)
}
|
|
|
// releaseConn inspects the outcome err of an operation on cn. If cn is still
// the active connection and isBadConn(err, allowTimeout, addr) reports it as
// bad, the connection is replaced via reconnect. A stale cn (already replaced)
// is ignored. Callers in this file invoke it with c.mu held.
func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
	if c.cn == cn && isBadConn(err, allowTimeout, c.opt.Addr) {
		c.reconnect(ctx, err)
	}
}
|
|
|
// reconnect discards the active connection (logging reason) and immediately
// tries to dial and resubscribe a new one.
//
// Both results are deliberately ignored. closeTheCn clears c.cn either way,
// and a failed redial leaves c.cn nil, so the next call to conn dials again.
func (c *PubSub) reconnect(ctx context.Context, reason error) {
	_ = c.closeTheCn(reason) // best effort
	_, _ = c.conn(ctx, nil)  // best effort; retried lazily by conn
}
|
|
|
func (c *PubSub) closeTheCn(reason error) error { |
|
if c.cn == nil { |
|
return nil |
|
} |
|
if !c.closed { |
|
internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason) |
|
} |
|
err := c.closeConn(c.cn) |
|
c.cn = nil |
|
return err |
|
} |
|
|
|
func (c *PubSub) Close() error { |
|
c.mu.Lock() |
|
defer c.mu.Unlock() |
|
|
|
if c.closed { |
|
return pool.ErrClosed |
|
} |
|
c.closed = true |
|
close(c.exit) |
|
|
|
return c.closeTheCn(pool.ErrClosed) |
|
} |
|
|
|
// Subscribe subscribes the client to the specified channels.
//
// The channels are recorded even when sending the command fails, so that
// resubscribe restores them on the next (re)connect.
func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Send before recording: conn may dial and resubscribe using the
	// previously tracked set, and must not subscribe these twice.
	sendErr := c.subscribe(ctx, "subscribe", channels...)

	if c.channels == nil {
		c.channels = make(map[string]struct{}, len(channels))
	}
	for _, name := range channels {
		c.channels[name] = struct{}{}
	}
	return sendErr
}
|
|
|
// PSubscribe subscribes the client to the given patterns.
//
// The patterns are recorded even when sending the command fails, so that
// resubscribe restores them on the next (re)connect.
func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Send before recording so a fresh connection's resubscribe does not
	// already include these patterns.
	sendErr := c.subscribe(ctx, "psubscribe", patterns...)

	if c.patterns == nil {
		c.patterns = make(map[string]struct{}, len(patterns))
	}
	for _, pattern := range patterns {
		c.patterns[pattern] = struct{}{}
	}
	return sendErr
}
|
|
|
// Unsubscribe the client from the given channels, or from all of
// them if none is given.
//
// The channels are removed from the tracked set before the command is sent,
// so a later reconnect (which replays the tracked set via resubscribe) does
// not subscribe to them again.
func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if len(channels) == 0 {
		// A bare UNSUBSCRIBE drops every channel on the server; forget them
		// locally too, otherwise resubscribe would restore them all.
		for channel := range c.channels {
			delete(c.channels, channel)
		}
	} else {
		for _, channel := range channels {
			delete(c.channels, channel)
		}
	}

	return c.subscribe(ctx, "unsubscribe", channels...)
}
|
|
|
// PUnsubscribe the client from the given patterns, or from all of
// them if none is given.
//
// The patterns are removed from the tracked set before the command is sent,
// so a later reconnect (which replays the tracked set via resubscribe) does
// not subscribe to them again.
func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if len(patterns) == 0 {
		// A bare PUNSUBSCRIBE drops every pattern on the server; forget them
		// locally too, otherwise resubscribe would restore them all.
		for pattern := range c.patterns {
			delete(c.patterns, pattern)
		}
	} else {
		for _, pattern := range patterns {
			delete(c.patterns, pattern)
		}
	}

	return c.subscribe(ctx, "punsubscribe", patterns...)
}
|
|
|
// subscribe sends redisCmd with channels on the active connection, dialing
// one if needed. It then lets releaseConn decide whether the connection must
// be replaced. Callers in this file invoke it with c.mu held.
func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
	cn, err := c.conn(ctx, channels)
	if err == nil {
		err = c._subscribe(ctx, cn, redisCmd, channels)
		c.releaseConn(ctx, cn, err, false)
	}
	return err
}
|
|
|
func (c *PubSub) Ping(ctx context.Context, payload ...string) error { |
|
args := []interface{}{"ping"} |
|
if len(payload) == 1 { |
|
args = append(args, payload[0]) |
|
} |
|
cmd := NewCmd(ctx, args...) |
|
|
|
c.mu.Lock() |
|
defer c.mu.Unlock() |
|
|
|
cn, err := c.conn(ctx, nil) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
err = c.writeCmd(ctx, cn, cmd) |
|
c.releaseConn(ctx, cn, err, false) |
|
return err |
|
} |
|
|
|
// Subscription received after a successful subscription to channel. |
|
type Subscription struct { |
|
// Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe". |
|
Kind string |
|
// Channel name we have subscribed to. |
|
Channel string |
|
// Number of channels we are currently subscribed to. |
|
Count int |
|
} |
|
|
|
func (m *Subscription) String() string { |
|
return fmt.Sprintf("%s: %s", m.Kind, m.Channel) |
|
} |
|
|
|
// Message received as result of a PUBLISH command issued by another client. |
|
type Message struct { |
|
Channel string |
|
Pattern string |
|
Payload string |
|
PayloadSlice []string |
|
} |
|
|
|
func (m *Message) String() string { |
|
return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload) |
|
} |
|
|
|
// Pong received as result of a PING command issued by another client. |
|
type Pong struct { |
|
Payload string |
|
} |
|
|
|
func (p *Pong) String() string { |
|
if p.Payload != "" { |
|
return fmt.Sprintf("Pong<%s>", p.Payload) |
|
} |
|
return "Pong" |
|
} |
|
|
|
// newMessage converts a raw Pub/Sub reply into a *Pong, *Subscription or
// *Message. A plain string reply is treated as a Pong payload. Array replies
// are dispatched on their first element ("subscribe", "message", ...).
//
// Replies with missing elements or elements of an unexpected type produce an
// error instead of a panic. A single malformed reply therefore cannot take
// down the goroutine that is receiving.
func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
	switch reply := reply.(type) {
	case string:
		return &Pong{
			Payload: reply,
		}, nil
	case []interface{}:
		// malformed reports an array reply that lacks the shape its kind needs.
		malformed := func() (interface{}, error) {
			return nil, fmt.Errorf("redis: malformed pubsub message: %#v", reply)
		}
		// str returns reply[i] as a string; ok is false when the element is
		// absent or is not a string.
		str := func(i int) (s string, ok bool) {
			if i < len(reply) {
				s, ok = reply[i].(string)
			}
			return s, ok
		}

		kind, ok := str(0)
		if !ok {
			return malformed()
		}

		switch kind {
		case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
			if len(reply) < 3 {
				return malformed()
			}
			// Can be nil in case of "unsubscribe".
			channel, _ := reply[1].(string)
			count, isInt := reply[2].(int64)
			if !isInt {
				return malformed()
			}
			return &Subscription{
				Kind:    kind,
				Channel: channel,
				Count:   int(count),
			}, nil
		case "message":
			channel, hasChannel := str(1)
			if !hasChannel || len(reply) < 3 {
				return malformed()
			}
			switch payload := reply[2].(type) {
			case string:
				return &Message{
					Channel: channel,
					Payload: payload,
				}, nil
			case []interface{}:
				ss := make([]string, len(payload))
				for i, elem := range payload {
					s, isStr := elem.(string)
					if !isStr {
						return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", elem)
					}
					ss[i] = s
				}
				return &Message{
					Channel:      channel,
					PayloadSlice: ss,
				}, nil
			default:
				return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload)
			}
		case "pmessage":
			pattern, ok1 := str(1)
			channel, ok2 := str(2)
			payload, ok3 := str(3)
			if !ok1 || !ok2 || !ok3 {
				return malformed()
			}
			return &Message{
				Pattern: pattern,
				Channel: channel,
				Payload: payload,
			}, nil
		case "pong":
			payload, hasPayload := str(1)
			if !hasPayload {
				return malformed()
			}
			return &Pong{
				Payload: payload,
			}, nil
		default:
			return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
		}
	default:
		return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
	}
}
|
|
|
// ReceiveTimeout acts like Receive but returns an error if message
// is not received in time. This is low-level API and in most cases
// Channel should be used instead.
//
// timeout is forwarded to the connection reader. When it is positive,
// releaseConn is told to allow timeouts (allowTimeout=true), presumably so a
// plain read timeout does not discard the connection; confirm against isBadConn.
func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
	if c.cmd == nil {
		c.cmd = NewCmd(ctx)
	}
	replyBuf := c.cmd

	// The lock is held only while obtaining the connection, not during the
	// blocking read, so Subscribe and Ping can proceed concurrently.
	cn, err := c.connWithLock(ctx)
	if err != nil {
		return nil, err
	}

	readErr := cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
		return replyBuf.readReply(rd)
	})
	c.releaseConnWithLock(ctx, cn, readErr, timeout > 0)
	if readErr != nil {
		return nil, readErr
	}

	return c.newMessage(replyBuf.Val())
}
|
|
|
// Receive returns a message as a Subscription, Message, Pong or error.
// It is ReceiveTimeout with a zero timeout.
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
	msg, err := c.ReceiveTimeout(ctx, 0)
	return msg, err
}
|
|
|
// ReceiveMessage blocks until a *Message arrives, silently skipping
// Subscription and Pong replies. Any receive error, or a value of an
// unexpected type, is returned immediately. This is low-level API and in most
// cases Channel should be used instead.
func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
	for {
		raw, err := c.Receive(ctx)
		if err != nil {
			return nil, err
		}

		switch v := raw.(type) {
		case *Message:
			return v, nil
		case *Subscription, *Pong:
			// Control replies are skipped; keep waiting.
		default:
			return nil, fmt.Errorf("redis: unknown message: %T", v)
		}
	}
}
|
|
|
// getContext returns the context stored in the receive reply buffer, or
// context.Background when nothing has been received yet. The buffer is
// created by the first ReceiveTimeout call, so this is presumably that
// call's ctx. It is used here for logging only.
func (c *PubSub) getContext() context.Context {
	if c.cmd == nil {
		return context.Background()
	}
	return c.cmd.ctx
}
|
|
|
//------------------------------------------------------------------------------ |
|
|
|
// Channel returns a Go channel for concurrently receiving messages.
// The channel is closed together with the PubSub. If the Go channel stays
// full for the send timeout (one minute by default, see
// WithChannelSendTimeout), the message is dropped.
// Receive* APIs can not be used after channel is created.
//
// Unless disabled with a zero WithChannelHealthCheckInterval, go-redis sends
// a PING whenever no message has arrived for the health check interval
// (3 seconds by default). If that PING fails, it reconnects and re-subscribes
// to all channels and patterns.
//
// The Go channel is created by the first call only; opts passed to later
// calls are ignored. Channel panics if ChannelWithSubscriptions was called
// first.
func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
	c.chOnce.Do(func() {
		c.msgCh = newChannel(c, opts...)
		c.msgCh.initMsgChan()
	})
	if c.msgCh == nil {
		err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
		panic(err)
	}
	return c.msgCh.msgCh
}
|
|
|
// ChannelSize is like Channel, but creates a Go channel |
|
// with specified buffer size. |
|
// |
|
// Deprecated: use Channel(WithChannelSize(size)), remove in v9. |
|
func (c *PubSub) ChannelSize(size int) <-chan *Message { |
|
return c.Channel(WithChannelSize(size)) |
|
} |
|
|
|
// ChannelWithSubscriptions is like Channel, but message type can be either
// *Subscription or *Message. Subscription messages can be used to detect
// reconnections. The context argument is unused. size sets the Go channel
// buffer; the other channel options keep their defaults.
//
// ChannelWithSubscriptions can not be used together with Channel or
// ChannelSize; it panics if one of those was called first.
func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} {
	c.chOnce.Do(func() {
		c.allCh = newChannel(c, WithChannelSize(size))
		c.allCh.initAllChan()
	})
	if c.allCh != nil {
		return c.allCh.allCh
	}
	panic(fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel"))
}
|
|
|
type ChannelOption func(c *channel) |
|
|
|
// WithChannelSize specifies the Go chan size that is used to buffer incoming messages. |
|
// |
|
// The default is 100 messages. |
|
func WithChannelSize(size int) ChannelOption { |
|
return func(c *channel) { |
|
c.chanSize = size |
|
} |
|
} |
|
|
|
// WithChannelHealthCheckInterval specifies the health check interval.
// PubSub will ping Redis Server if it does not receive any messages within the interval.
// A zero (or negative) interval disables the health check.
//
// The default is 3 seconds.
func WithChannelHealthCheckInterval(d time.Duration) ChannelOption {
	return func(ch *channel) {
		ch.checkInterval = d
	}
}
|
|
|
// WithChannelSendTimeout specifies how long a received message may wait for
// room in a full Go channel before it is dropped.
//
// The default is 60 seconds.
func WithChannelSendTimeout(d time.Duration) ChannelOption {
	return func(ch *channel) {
		ch.chanSendTimeout = d
	}
}
|
|
|
type channel struct { |
|
pubSub *PubSub |
|
|
|
msgCh chan *Message |
|
allCh chan interface{} |
|
ping chan struct{} |
|
|
|
chanSize int |
|
chanSendTimeout time.Duration |
|
checkInterval time.Duration |
|
} |
|
|
|
// newChannel builds a channel for pubSub with the defaults: a buffer of 100,
// a one-minute send timeout and a 3-second health-check interval. It then
// applies opts in order and starts the health check when the resulting
// interval is positive.
func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
	ch := &channel{
		pubSub:          pubSub,
		chanSize:        100,
		chanSendTimeout: time.Minute,
		checkInterval:   3 * time.Second,
	}
	for _, apply := range opts {
		apply(ch)
	}
	if ch.checkInterval > 0 {
		ch.initHealthCheck()
	}
	return ch
}
|
|
|
// initHealthCheck creates the ping signal channel and starts a goroutine that
// sends a PING whenever no signal has arrived within checkInterval. If the
// PING fails, it forces a reconnect under the PubSub lock. The goroutine
// returns when the PubSub's exit channel is closed.
func (c *channel) initHealthCheck() {
	ctx := context.TODO()
	c.ping = make(chan struct{}, 1)

	go func() {
		// Start with a stopped timer; it is re-armed at the top of each loop.
		timer := time.NewTimer(time.Minute)
		timer.Stop()

		for {
			timer.Reset(c.checkInterval)
			select {
			case <-c.ping:
				// Traffic arrived: disarm the timer, draining it if it already
				// fired, so the next Reset starts a clean interval.
				if !timer.Stop() {
					<-timer.C
				}
			case <-timer.C:
				pingErr := c.pubSub.Ping(ctx)
				if pingErr == nil {
					continue
				}
				c.pubSub.mu.Lock()
				c.pubSub.reconnect(ctx, pingErr)
				c.pubSub.mu.Unlock()
			case <-c.pubSub.exit:
				return
			}
		}
	}()
}
|
|
|
// initMsgChan creates msgCh and starts the goroutine that feeds it.
//
// The goroutine loops on Receive. Every successful receive signals the health
// check. *Message values are forwarded; a message is dropped (and logged) if
// msgCh stays full for chanSendTimeout. *Subscription and *Pong are skipped.
// msgCh is closed once Receive reports pool.ErrClosed. Other receive errors
// are retried, sleeping 100ms before each retry after the first consecutive
// failure.
//
// initMsgChan must be in sync with initAllChan.
func (c *channel) initMsgChan() {
	ctx := context.TODO()
	c.msgCh = make(chan *Message, c.chanSize)

	go func() {
		timer := time.NewTimer(time.Minute)
		timer.Stop()

		var errCount int
		for {
			msg, err := c.pubSub.Receive(ctx)
			if err != nil {
				if err == pool.ErrClosed {
					close(c.msgCh)
					return
				}
				if errCount > 0 {
					time.Sleep(100 * time.Millisecond)
				}
				errCount++
				continue
			}

			errCount = 0

			// Any message is as good as a ping.
			select {
			case c.ping <- struct{}{}:
			default:
			}

			switch msg := msg.(type) {
			case *Subscription:
				// Ignore.
			case *Pong:
				// Ignore.
			case *Message:
				timer.Reset(c.chanSendTimeout)
				select {
				case c.msgCh <- msg:
					if !timer.Stop() {
						<-timer.C
					}
				case <-timer.C:
					// c (*channel) has no String method, so formatting it
					// with %s printed raw struct internals; log only the
					// timeout.
					internal.Logger.Printf(
						ctx, "redis: PubSub channel is full for %s (message is dropped)",
						c.chanSendTimeout)
				}
			default:
				internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
			}
		}
	}()
}
|
|
|
// initAllChan creates allCh and starts the goroutine that feeds it.
//
// The goroutine loops on Receive. Every successful receive signals the health
// check. *Subscription and *Message values are forwarded; a value is dropped
// (and logged) if allCh stays full for chanSendTimeout. *Pong is skipped.
// allCh is closed once Receive reports pool.ErrClosed. Other receive errors
// are retried, sleeping 100ms before each retry after the first consecutive
// failure.
//
// initAllChan must be in sync with initMsgChan.
func (c *channel) initAllChan() {
	ctx := context.TODO()
	c.allCh = make(chan interface{}, c.chanSize)

	go func() {
		timer := time.NewTimer(time.Minute)
		timer.Stop()

		var errCount int
		for {
			msg, err := c.pubSub.Receive(ctx)
			if err != nil {
				if err == pool.ErrClosed {
					close(c.allCh)
					return
				}
				if errCount > 0 {
					time.Sleep(100 * time.Millisecond)
				}
				errCount++
				continue
			}

			errCount = 0

			// Any message is as good as a ping.
			select {
			case c.ping <- struct{}{}:
			default:
			}

			switch msg := msg.(type) {
			case *Pong:
				// Ignore.
			case *Subscription, *Message:
				timer.Reset(c.chanSendTimeout)
				select {
				case c.allCh <- msg:
					if !timer.Stop() {
						<-timer.C
					}
				case <-timer.C:
					// c (*channel) has no String method, so formatting it
					// with %s printed raw struct internals; log only the
					// timeout.
					internal.Logger.Printf(
						ctx, "redis: PubSub channel is full for %s (message is dropped)",
						c.chanSendTimeout)
				}
			default:
				internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
			}
		}
	}()
}
|
|
|