You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
379 lines
8.7 KiB
379 lines
8.7 KiB
package websocket // import "nhooyr.io/websocket" |
|
|
|
import ( |
|
"bytes" |
|
"context" |
|
"errors" |
|
"fmt" |
|
"io" |
|
"net/http" |
|
"reflect" |
|
"runtime" |
|
"strings" |
|
"sync" |
|
"syscall/js" |
|
|
|
"nhooyr.io/websocket/internal/bpool" |
|
"nhooyr.io/websocket/internal/wsjs" |
|
"nhooyr.io/websocket/internal/xsync" |
|
) |
|
|
|
// Conn provides a wrapper around the browser WebSocket API. |
|
type Conn struct { |
|
ws wsjs.WebSocket |
|
|
|
// read limit for a message in bytes. |
|
msgReadLimit xsync.Int64 |
|
|
|
closingMu sync.Mutex |
|
isReadClosed xsync.Int64 |
|
closeOnce sync.Once |
|
closed chan struct{} |
|
closeErrOnce sync.Once |
|
closeErr error |
|
closeWasClean bool |
|
|
|
releaseOnClose func() |
|
releaseOnMessage func() |
|
|
|
readSignal chan struct{} |
|
readBufMu sync.Mutex |
|
readBuf []wsjs.MessageEvent |
|
} |
|
|
|
func (c *Conn) close(err error, wasClean bool) { |
|
c.closeOnce.Do(func() { |
|
runtime.SetFinalizer(c, nil) |
|
|
|
if !wasClean { |
|
err = fmt.Errorf("unclean connection close: %w", err) |
|
} |
|
c.setCloseErr(err) |
|
c.closeWasClean = wasClean |
|
close(c.closed) |
|
}) |
|
} |
|
|
|
func (c *Conn) init() { |
|
c.closed = make(chan struct{}) |
|
c.readSignal = make(chan struct{}, 1) |
|
|
|
c.msgReadLimit.Store(32768) |
|
|
|
c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { |
|
err := CloseError{ |
|
Code: StatusCode(e.Code), |
|
Reason: e.Reason, |
|
} |
|
// We do not know if we sent or received this close as |
|
// its possible the browser triggered it without us |
|
// explicitly sending it. |
|
c.close(err, e.WasClean) |
|
|
|
c.releaseOnClose() |
|
c.releaseOnMessage() |
|
}) |
|
|
|
c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) { |
|
c.readBufMu.Lock() |
|
defer c.readBufMu.Unlock() |
|
|
|
c.readBuf = append(c.readBuf, e) |
|
|
|
// Lets the read goroutine know there is definitely something in readBuf. |
|
select { |
|
case c.readSignal <- struct{}{}: |
|
default: |
|
} |
|
}) |
|
|
|
runtime.SetFinalizer(c, func(c *Conn) { |
|
c.setCloseErr(errors.New("connection garbage collected")) |
|
c.closeWithInternal() |
|
}) |
|
} |
|
|
|
func (c *Conn) closeWithInternal() { |
|
c.Close(StatusInternalError, "something went wrong") |
|
} |
|
|
|
// Read attempts to read a message from the connection. |
|
// The maximum time spent waiting is bounded by the context. |
|
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { |
|
if c.isReadClosed.Load() == 1 { |
|
return 0, nil, errors.New("WebSocket connection read closed") |
|
} |
|
|
|
typ, p, err := c.read(ctx) |
|
if err != nil { |
|
return 0, nil, fmt.Errorf("failed to read: %w", err) |
|
} |
|
if int64(len(p)) > c.msgReadLimit.Load() { |
|
err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) |
|
c.Close(StatusMessageTooBig, err.Error()) |
|
return 0, nil, err |
|
} |
|
return typ, p, nil |
|
} |
|
|
|
func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { |
|
select { |
|
case <-ctx.Done(): |
|
c.Close(StatusPolicyViolation, "read timed out") |
|
return 0, nil, ctx.Err() |
|
case <-c.readSignal: |
|
case <-c.closed: |
|
return 0, nil, c.closeErr |
|
} |
|
|
|
c.readBufMu.Lock() |
|
defer c.readBufMu.Unlock() |
|
|
|
me := c.readBuf[0] |
|
// We copy the messages forward and decrease the size |
|
// of the slice to avoid reallocating. |
|
copy(c.readBuf, c.readBuf[1:]) |
|
c.readBuf = c.readBuf[:len(c.readBuf)-1] |
|
|
|
if len(c.readBuf) > 0 { |
|
// Next time we read, we'll grab the message. |
|
select { |
|
case c.readSignal <- struct{}{}: |
|
default: |
|
} |
|
} |
|
|
|
switch p := me.Data.(type) { |
|
case string: |
|
return MessageText, []byte(p), nil |
|
case []byte: |
|
return MessageBinary, p, nil |
|
default: |
|
panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String()) |
|
} |
|
} |
|
|
|
// Ping is mocked out for Wasm. |
|
func (c *Conn) Ping(ctx context.Context) error { |
|
return nil |
|
} |
|
|
|
// Write writes a message of the given type to the connection. |
|
// Always non blocking. |
|
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { |
|
err := c.write(ctx, typ, p) |
|
if err != nil { |
|
// Have to ensure the WebSocket is closed after a write error |
|
// to match the Go API. It can only error if the message type |
|
// is unexpected or the passed bytes contain invalid UTF-8 for |
|
// MessageText. |
|
err := fmt.Errorf("failed to write: %w", err) |
|
c.setCloseErr(err) |
|
c.closeWithInternal() |
|
return err |
|
} |
|
return nil |
|
} |
|
|
|
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { |
|
if c.isClosed() { |
|
return c.closeErr |
|
} |
|
switch typ { |
|
case MessageBinary: |
|
return c.ws.SendBytes(p) |
|
case MessageText: |
|
return c.ws.SendText(string(p)) |
|
default: |
|
return fmt.Errorf("unexpected message type: %v", typ) |
|
} |
|
} |
|
|
|
// Close closes the WebSocket with the given code and reason. |
|
// It will wait until the peer responds with a close frame |
|
// or the connection is closed. |
|
// It thus performs the full WebSocket close handshake. |
|
func (c *Conn) Close(code StatusCode, reason string) error { |
|
err := c.exportedClose(code, reason) |
|
if err != nil { |
|
return fmt.Errorf("failed to close WebSocket: %w", err) |
|
} |
|
return nil |
|
} |
|
|
|
func (c *Conn) exportedClose(code StatusCode, reason string) error { |
|
c.closingMu.Lock() |
|
defer c.closingMu.Unlock() |
|
|
|
ce := fmt.Errorf("sent close: %w", CloseError{ |
|
Code: code, |
|
Reason: reason, |
|
}) |
|
|
|
if c.isClosed() { |
|
return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) |
|
} |
|
|
|
c.setCloseErr(ce) |
|
err := c.ws.Close(int(code), reason) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
<-c.closed |
|
if !c.closeWasClean { |
|
return c.closeErr |
|
} |
|
return nil |
|
} |
|
|
|
// Subprotocol returns the negotiated subprotocol. |
|
// An empty string means the default protocol. |
|
func (c *Conn) Subprotocol() string { |
|
return c.ws.Subprotocol() |
|
} |
|
|
|
// DialOptions holds the optional settings accepted by Dial.
type DialOptions struct {
	// Subprotocols is the list of subprotocols offered to the server
	// during negotiation.
	Subprotocols []string
}
|
|
|
// Dial creates a new WebSocket connection to the given url with the given options. |
|
// The passed context bounds the maximum time spent waiting for the connection to open. |
|
// The returned *http.Response is always nil or a mock. It's only in the signature |
|
// to match the core API. |
|
func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { |
|
c, resp, err := dial(ctx, url, opts) |
|
if err != nil { |
|
return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err) |
|
} |
|
return c, resp, nil |
|
} |
|
|
|
func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { |
|
if opts == nil { |
|
opts = &DialOptions{} |
|
} |
|
|
|
url = strings.Replace(url, "http://", "ws://", 1) |
|
url = strings.Replace(url, "https://", "wss://", 1) |
|
|
|
ws, err := wsjs.New(url, opts.Subprotocols) |
|
if err != nil { |
|
return nil, nil, err |
|
} |
|
|
|
c := &Conn{ |
|
ws: ws, |
|
} |
|
c.init() |
|
|
|
opench := make(chan struct{}) |
|
releaseOpen := ws.OnOpen(func(e js.Value) { |
|
close(opench) |
|
}) |
|
defer releaseOpen() |
|
|
|
select { |
|
case <-ctx.Done(): |
|
c.Close(StatusPolicyViolation, "dial timed out") |
|
return nil, nil, ctx.Err() |
|
case <-opench: |
|
return c, &http.Response{ |
|
StatusCode: http.StatusSwitchingProtocols, |
|
}, nil |
|
case <-c.closed: |
|
return nil, nil, c.closeErr |
|
} |
|
} |
|
|
|
// Reader attempts to read a message from the connection. |
|
// The maximum time spent waiting is bounded by the context. |
|
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { |
|
typ, p, err := c.Read(ctx) |
|
if err != nil { |
|
return 0, nil, err |
|
} |
|
return typ, bytes.NewReader(p), nil |
|
} |
|
|
|
// Writer returns a writer to write a WebSocket data message to the connection. |
|
// It buffers the entire message in memory and then sends it when the writer |
|
// is closed. |
|
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { |
|
return writer{ |
|
c: c, |
|
ctx: ctx, |
|
typ: typ, |
|
b: bpool.Get(), |
|
}, nil |
|
} |
|
|
|
type writer struct { |
|
closed bool |
|
|
|
c *Conn |
|
ctx context.Context |
|
typ MessageType |
|
|
|
b *bytes.Buffer |
|
} |
|
|
|
func (w writer) Write(p []byte) (int, error) { |
|
if w.closed { |
|
return 0, errors.New("cannot write to closed writer") |
|
} |
|
n, err := w.b.Write(p) |
|
if err != nil { |
|
return n, fmt.Errorf("failed to write message: %w", err) |
|
} |
|
return n, nil |
|
} |
|
|
|
func (w writer) Close() error { |
|
if w.closed { |
|
return errors.New("cannot close closed writer") |
|
} |
|
w.closed = true |
|
defer bpool.Put(w.b) |
|
|
|
err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) |
|
if err != nil { |
|
return fmt.Errorf("failed to close writer: %w", err) |
|
} |
|
return nil |
|
} |
|
|
|
// CloseRead implements *Conn.CloseRead for wasm. |
|
func (c *Conn) CloseRead(ctx context.Context) context.Context { |
|
c.isReadClosed.Store(1) |
|
|
|
ctx, cancel := context.WithCancel(ctx) |
|
go func() { |
|
defer cancel() |
|
c.read(ctx) |
|
c.Close(StatusPolicyViolation, "unexpected data message") |
|
}() |
|
return ctx |
|
} |
|
|
|
// SetReadLimit implements *Conn.SetReadLimit for wasm. |
|
func (c *Conn) SetReadLimit(n int64) { |
|
c.msgReadLimit.Store(n) |
|
} |
|
|
|
func (c *Conn) setCloseErr(err error) { |
|
c.closeErrOnce.Do(func() { |
|
c.closeErr = fmt.Errorf("WebSocket closed: %w", err) |
|
}) |
|
} |
|
|
|
func (c *Conn) isClosed() bool { |
|
select { |
|
case <-c.closed: |
|
return true |
|
default: |
|
return false |
|
} |
|
}
|
|
|