// +build !js

package websocket

import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io"
	"runtime"
	"strconv"
	"sync"
	"sync/atomic"
)

// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
//
// On any error from any method, the connection is closed
// with an appropriate reason.
type Conn struct {
	subprotocol    string
	rwc            io.ReadWriteCloser
	client         bool
	copts          *compressionOptions
	flateThreshold int
	br             *bufio.Reader
	bw             *bufio.Writer

	readTimeout  chan context.Context
	writeTimeout chan context.Context

	// Read state.
	readMu            *mu
	readHeaderBuf     [8]byte
	readControlBuf    [maxControlPayload]byte
	msgReader         *msgReader
	readCloseFrameErr error

	// Write state.
	msgWriterState *msgWriterState
	writeFrameMu   *mu
	writeBuf       []byte
	writeHeaderBuf [8]byte
	writeHeader    header

	closed     chan struct{}
	closeMu    sync.Mutex
	closeErr   error
	wroteClose bool

	pingCounter   int32
	activePingsMu sync.Mutex
	activePings   map[string]chan<- struct{}
}

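// A minimal usage sketch for Conn, assuming the package's Dial constructor
// and the Read/Close methods referenced in the doc comment above (none of
// which are defined in this file):
//
//	// ctx bounds the dial and every read.
//	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
//	defer cancel()
//
//	c, _, err := websocket.Dial(ctx, "ws://example.com/ws", nil)
//	if err != nil {
//		// handle err
//	}
//	// Always close the connection to release associated resources.
//	defer c.Close(websocket.StatusInternalError, "deferred close")
//
//	// Keep reading so control frames (pings, close) are handled.
//	for {
//		_, msg, err := c.Read(ctx)
//		if err != nil {
//			break
//		}
//		_ = msg // process the message
//	}
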
type connConfig struct {
	subprotocol    string
	rwc            io.ReadWriteCloser
	client         bool
	copts          *compressionOptions
	flateThreshold int

	br *bufio.Reader
	bw *bufio.Writer
}

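// newConn assembles a Conn from the handshake state in cfg, starts the
// timeout loop and registers a finalizer that closes the connection with an
// error if it is garbage collected without having been closed.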
func newConn(cfg connConfig) *Conn {
	c := &Conn{
		subprotocol:    cfg.subprotocol,
		rwc:            cfg.rwc,
		client:         cfg.client,
		copts:          cfg.copts,
		flateThreshold: cfg.flateThreshold,

		br: cfg.br,
		bw: cfg.bw,

		readTimeout:  make(chan context.Context),
		writeTimeout: make(chan context.Context),

		closed:      make(chan struct{}),
		activePings: make(map[string]chan<- struct{}),
	}

	c.readMu = newMu(c)
	c.writeFrameMu = newMu(c)

	c.msgReader = newMsgReader(c)

	c.msgWriterState = newMsgWriterState(c)
	if c.client {
		c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
	}

	if c.flate() && c.flateThreshold == 0 {
		c.flateThreshold = 128
		if !c.msgWriterState.flateContextTakeover() {
			c.flateThreshold = 512
		}
	}

	runtime.SetFinalizer(c, func(c *Conn) {
		c.close(errors.New("connection garbage collected"))
	})

	go c.timeoutLoop()

	return c
}

// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
	return c.subprotocol
}

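// close records err as the reason the connection was closed, closes the
// underlying io.ReadWriteCloser and shuts down the message writer and reader.
// Calling it more than once is a no-op.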
func (c *Conn) close(err error) {
	c.closeMu.Lock()
	defer c.closeMu.Unlock()

	if c.isClosed() {
		return
	}
	c.setCloseErrLocked(err)
	close(c.closed)
	runtime.SetFinalizer(c, nil)

	// c.rwc must be closed only after c.closed is closed, so that any
	// goroutine woken by the underlying connection closing also sees that
	// c.closed is closed and returns closeErr.
	c.rwc.Close()

	go func() {
		c.msgWriterState.close()

		c.msgReader.close()
	}()
}

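// timeoutLoop runs in its own goroutine for the lifetime of the connection.
// It tracks the current read and write deadline contexts; when the read
// deadline expires it records the error and calls writeError with
// StatusPolicyViolation, and when the write deadline expires it closes the
// connection.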
func (c *Conn) timeoutLoop() {
	readCtx := context.Background()
	writeCtx := context.Background()

	for {
		select {
		case <-c.closed:
			return

		case writeCtx = <-c.writeTimeout:
		case readCtx = <-c.readTimeout:

		case <-readCtx.Done():
			c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
			go c.writeError(StatusPolicyViolation, errors.New("timed out"))
		case <-writeCtx.Done():
			c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
			return
		}
	}
}

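// flate reports whether compression was negotiated for this connection.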
func (c *Conn) flate() bool {
	return c.copts != nil
}

// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or ensure the peer is responsive.
// Ping must be called concurrently with Reader as it does
// not read from the connection but instead waits for a Reader call
// to read the pong.
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
	p := atomic.AddInt32(&c.pingCounter, 1)

	err := c.ping(ctx, strconv.Itoa(int(p)))
	if err != nil {
		return fmt.Errorf("failed to ping: %w", err)
	}
	return nil
}

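// A minimal sketch of using Ping to check peer responsiveness, assuming a
// *Conn named c obtained elsewhere and a Reader/Read loop running
// concurrently so the pong can be received (see the Ping doc comment above):
//
//	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//	defer cancel()
//
//	start := time.Now()
//	if err := c.Ping(ctx); err != nil {
//		// The peer did not answer in time or the connection failed.
//		return err
//	}
//	latency := time.Since(start) // round-trip time to the peer
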
func (c *Conn) ping(ctx context.Context, p string) error {
	pong := make(chan struct{}, 1)

	c.activePingsMu.Lock()
	c.activePings[p] = pong
	c.activePingsMu.Unlock()

	defer func() {
		c.activePingsMu.Lock()
		delete(c.activePings, p)
		c.activePingsMu.Unlock()
	}()

	err := c.writeControl(ctx, opPing, []byte(p))
	if err != nil {
		return err
	}

	select {
	case <-c.closed:
		return c.closeErr
	case <-ctx.Done():
		err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
		c.close(err)
		return err
	case <-pong:
		return nil
	}
}

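// mu is a context-aware mutex built on a one-slot channel and tied to the
// lifetime of its Conn.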
type mu struct {
	c  *Conn
	ch chan struct{}
}

func newMu(c *Conn) *mu {
	return &mu{
		c:  c,
		ch: make(chan struct{}, 1),
	}
}

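// forceLock acquires the mutex unconditionally, blocking until it is
// available and ignoring both context cancellation and connection closure.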
func (m *mu) forceLock() {
	m.ch <- struct{}{}
}

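// lock acquires the mutex unless ctx is done or the connection is closed
// first. If ctx expires, the connection is closed and that error returned.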
func (m *mu) lock(ctx context.Context) error {
	select {
	case <-m.c.closed:
		return m.c.closeErr
	case <-ctx.Done():
		err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
		m.c.close(err)
		return err
	case m.ch <- struct{}{}:
		// Double-check the connection is still alive: the send on m.ch
		// may have been selected over the receive on closed.
		select {
		case <-m.c.closed:
			// Make sure to release.
			m.unlock()
			return m.c.closeErr
		default:
		}
		return nil
	}
}

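// unlock releases the mutex if it is held and is a no-op otherwise.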
func (m *mu) unlock() {
	select {
	case <-m.ch:
	default:
	}
}