You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
293 lines
7.7 KiB
293 lines
7.7 KiB
3 years ago
|
// +build !js
|
||
|
|
||
|
package websocket
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"crypto/rand"
|
||
|
"encoding/base64"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"net/http"
|
||
|
"net/url"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"nhooyr.io/websocket/internal/errd"
|
||
|
)
|
||
|
|
||
|
// DialOptions represents Dial's options.
|
||
|
type DialOptions struct {
|
||
|
// HTTPClient is used for the connection.
|
||
|
// Its Transport must return writable bodies for WebSocket handshakes.
|
||
|
// http.Transport does beginning with Go 1.12.
|
||
|
HTTPClient *http.Client
|
||
|
|
||
|
// HTTPHeader specifies the HTTP headers included in the handshake request.
|
||
|
HTTPHeader http.Header
|
||
|
|
||
|
// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
|
||
|
Subprotocols []string
|
||
|
|
||
|
// CompressionMode controls the compression mode.
|
||
|
// Defaults to CompressionNoContextTakeover.
|
||
|
//
|
||
|
// See docs on CompressionMode for details.
|
||
|
CompressionMode CompressionMode
|
||
|
|
||
|
// CompressionThreshold controls the minimum size of a message before compression is applied.
|
||
|
//
|
||
|
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
|
||
|
// for CompressionContextTakeover.
|
||
|
CompressionThreshold int
|
||
|
}
|
||
|
|
||
|
// Dial performs a WebSocket handshake on url.
|
||
|
//
|
||
|
// The response is the WebSocket handshake response from the server.
|
||
|
// You never need to close resp.Body yourself.
|
||
|
//
|
||
|
// If an error occurs, the returned response may be non nil.
|
||
|
// However, you can only read the first 1024 bytes of the body.
|
||
|
//
|
||
|
// This function requires at least Go 1.12 as it uses a new feature
|
||
|
// in net/http to perform WebSocket handshakes.
|
||
|
// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
|
||
|
//
|
||
|
// URLs with http/https schemes will work and are interpreted as ws/wss.
|
||
|
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
|
||
|
return dial(ctx, u, opts, nil)
|
||
|
}
|
||
|
|
||
|
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
|
||
|
defer errd.Wrap(&err, "failed to WebSocket dial")
|
||
|
|
||
|
if opts == nil {
|
||
|
opts = &DialOptions{}
|
||
|
}
|
||
|
|
||
|
opts = &*opts
|
||
|
if opts.HTTPClient == nil {
|
||
|
opts.HTTPClient = http.DefaultClient
|
||
|
} else if opts.HTTPClient.Timeout > 0 {
|
||
|
var cancel context.CancelFunc
|
||
|
|
||
|
ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout)
|
||
|
defer cancel()
|
||
|
|
||
|
newClient := *opts.HTTPClient
|
||
|
newClient.Timeout = 0
|
||
|
opts.HTTPClient = &newClient
|
||
|
}
|
||
|
|
||
|
if opts.HTTPHeader == nil {
|
||
|
opts.HTTPHeader = http.Header{}
|
||
|
}
|
||
|
|
||
|
secWebSocketKey, err := secWebSocketKey(rand)
|
||
|
if err != nil {
|
||
|
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
|
||
|
}
|
||
|
|
||
|
var copts *compressionOptions
|
||
|
if opts.CompressionMode != CompressionDisabled {
|
||
|
copts = opts.CompressionMode.opts()
|
||
|
}
|
||
|
|
||
|
resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
|
||
|
if err != nil {
|
||
|
return nil, resp, err
|
||
|
}
|
||
|
respBody := resp.Body
|
||
|
resp.Body = nil
|
||
|
defer func() {
|
||
|
if err != nil {
|
||
|
// We read a bit of the body for easier debugging.
|
||
|
r := io.LimitReader(respBody, 1024)
|
||
|
|
||
|
timer := time.AfterFunc(time.Second*3, func() {
|
||
|
respBody.Close()
|
||
|
})
|
||
|
defer timer.Stop()
|
||
|
|
||
|
b, _ := ioutil.ReadAll(r)
|
||
|
respBody.Close()
|
||
|
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
|
||
|
if err != nil {
|
||
|
return nil, resp, err
|
||
|
}
|
||
|
|
||
|
rwc, ok := respBody.(io.ReadWriteCloser)
|
||
|
if !ok {
|
||
|
return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
|
||
|
}
|
||
|
|
||
|
return newConn(connConfig{
|
||
|
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
|
||
|
rwc: rwc,
|
||
|
client: true,
|
||
|
copts: copts,
|
||
|
flateThreshold: opts.CompressionThreshold,
|
||
|
br: getBufioReader(rwc),
|
||
|
bw: getBufioWriter(rwc),
|
||
|
}), resp, nil
|
||
|
}
|
||
|
|
||
|
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
|
||
|
u, err := url.Parse(urls)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to parse url: %w", err)
|
||
|
}
|
||
|
|
||
|
switch u.Scheme {
|
||
|
case "ws":
|
||
|
u.Scheme = "http"
|
||
|
case "wss":
|
||
|
u.Scheme = "https"
|
||
|
case "http", "https":
|
||
|
default:
|
||
|
return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
|
||
|
}
|
||
|
|
||
|
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
|
||
|
req.Header = opts.HTTPHeader.Clone()
|
||
|
req.Header.Set("Connection", "Upgrade")
|
||
|
req.Header.Set("Upgrade", "websocket")
|
||
|
req.Header.Set("Sec-WebSocket-Version", "13")
|
||
|
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
|
||
|
if len(opts.Subprotocols) > 0 {
|
||
|
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
|
||
|
}
|
||
|
if copts != nil {
|
||
|
copts.setHeader(req.Header)
|
||
|
}
|
||
|
|
||
|
resp, err := opts.HTTPClient.Do(req)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to send handshake request: %w", err)
|
||
|
}
|
||
|
return resp, nil
|
||
|
}
|
||
|
|
||
|
// secWebSocketKey builds a Sec-WebSocket-Key value by reading exactly 16
// bytes from rr and encoding them with standard, padded base64. A nil rr
// means crypto/rand.Reader. A short or failing read is reported as a wrapped
// error.
func secWebSocketKey(rr io.Reader) (string, error) {
	src := rr
	if src == nil {
		src = rand.Reader
	}

	var nonce [16]byte
	if _, err := io.ReadFull(src, nonce[:]); err != nil {
		return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
|
||
|
|
||
|
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
|
||
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||
|
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
|
||
|
}
|
||
|
|
||
|
if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
|
||
|
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
|
||
|
}
|
||
|
|
||
|
if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
|
||
|
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
|
||
|
}
|
||
|
|
||
|
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
|
||
|
return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
|
||
|
resp.Header.Get("Sec-WebSocket-Accept"),
|
||
|
secWebSocketKey,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
err := verifySubprotocol(opts.Subprotocols, resp)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return verifyServerExtensions(copts, resp.Header)
|
||
|
}
|
||
|
|
||
|
// verifySubprotocol checks the subprotocol the server selected via the
// Sec-WebSocket-Protocol response header. An absent or empty header is always
// accepted. Otherwise the value must match one of subprotos,
// case-insensitively.
func verifySubprotocol(subprotos []string, resp *http.Response) error {
	selected := resp.Header.Get("Sec-WebSocket-Protocol")

	accepted := selected == ""
	for i := 0; !accepted && i < len(subprotos); i++ {
		accepted = strings.EqualFold(subprotos[i], selected)
	}
	if accepted {
		return nil
	}

	return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", selected)
}
|
||
|
|
||
|
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
|
||
|
exts := websocketExtensions(h)
|
||
|
if len(exts) == 0 {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
ext := exts[0]
|
||
|
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
|
||
|
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
|
||
|
}
|
||
|
|
||
|
copts = &*copts
|
||
|
|
||
|
for _, p := range ext.params {
|
||
|
switch p {
|
||
|
case "client_no_context_takeover":
|
||
|
copts.clientNoContextTakeover = true
|
||
|
continue
|
||
|
case "server_no_context_takeover":
|
||
|
copts.serverNoContextTakeover = true
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
|
||
|
}
|
||
|
|
||
|
return copts, nil
|
||
|
}
|
||
|
|
||
|
// bufioReaderPool recycles *bufio.Reader values. It has no New function, so
// Get yields nil while the pool is empty.
var bufioReaderPool sync.Pool

// getBufioReader returns a *bufio.Reader reading from r. It reuses a pooled
// reader (Reset onto r) when one is available and allocates a fresh one
// otherwise.
func getBufioReader(r io.Reader) *bufio.Reader {
	if br, ok := bufioReaderPool.Get().(*bufio.Reader); ok {
		br.Reset(r)
		return br
	}
	return bufio.NewReader(r)
}

// putBufioReader hands br back to bufioReaderPool. A later getBufioReader may
// return it again, so callers should stop using br after this call.
func putBufioReader(br *bufio.Reader) {
	bufioReaderPool.Put(br)
}
|
||
|
|
||
|
// bufioWriterPool recycles *bufio.Writer values. It has no New function, so
// Get yields nil while the pool is empty.
var bufioWriterPool sync.Pool

// getBufioWriter returns a *bufio.Writer writing to w. It reuses a pooled
// writer (Reset onto w, discarding any unflushed data) when available and
// allocates a fresh one otherwise.
func getBufioWriter(w io.Writer) *bufio.Writer {
	if bw, ok := bufioWriterPool.Get().(*bufio.Writer); ok {
		bw.Reset(w)
		return bw
	}
	return bufio.NewWriter(w)
}

// putBufioWriter hands bw back to bufioWriterPool. A later getBufioWriter may
// return it again, so callers should stop using bw after this call.
func putBufioWriter(bw *bufio.Writer) {
	bufioWriterPool.Put(bw)
}
|