You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
168 lines
3.7 KiB
168 lines
3.7 KiB
// Copyright 2018 The Go Authors. All rights reserved. |
|
// Use of this source code is governed by a BSD-style |
|
// license that can be found in the LICENSE file. |
|
|
|
package socks |
|
|
|
import ( |
|
"context" |
|
"errors" |
|
"io" |
|
"net" |
|
"strconv" |
|
"time" |
|
) |
|
|
|
// Sentinel deadlines handed to net.Conn.SetDeadline by connect.
var (
	// noDeadline is the zero time.Time. connect passes it to SetDeadline on
	// return to remove the deadline it installed from the context.
	noDeadline time.Time

	// aLongTimeAgo is one second past the Unix epoch, i.e. a deadline that
	// has already expired. connect sets it on the connection once the
	// context is done.
	aLongTimeAgo = time.Unix(1, 0)
)
|
|
|
func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) { |
|
host, port, err := splitHostPort(address) |
|
if err != nil { |
|
return nil, err |
|
} |
|
if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { |
|
c.SetDeadline(deadline) |
|
defer c.SetDeadline(noDeadline) |
|
} |
|
if ctx != context.Background() { |
|
errCh := make(chan error, 1) |
|
done := make(chan struct{}) |
|
defer func() { |
|
close(done) |
|
if ctxErr == nil { |
|
ctxErr = <-errCh |
|
} |
|
}() |
|
go func() { |
|
select { |
|
case <-ctx.Done(): |
|
c.SetDeadline(aLongTimeAgo) |
|
errCh <- ctx.Err() |
|
case <-done: |
|
errCh <- nil |
|
} |
|
}() |
|
} |
|
|
|
b := make([]byte, 0, 6+len(host)) // the size here is just an estimate |
|
b = append(b, Version5) |
|
if len(d.AuthMethods) == 0 || d.Authenticate == nil { |
|
b = append(b, 1, byte(AuthMethodNotRequired)) |
|
} else { |
|
ams := d.AuthMethods |
|
if len(ams) > 255 { |
|
return nil, errors.New("too many authentication methods") |
|
} |
|
b = append(b, byte(len(ams))) |
|
for _, am := range ams { |
|
b = append(b, byte(am)) |
|
} |
|
} |
|
if _, ctxErr = c.Write(b); ctxErr != nil { |
|
return |
|
} |
|
|
|
if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil { |
|
return |
|
} |
|
if b[0] != Version5 { |
|
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) |
|
} |
|
am := AuthMethod(b[1]) |
|
if am == AuthMethodNoAcceptableMethods { |
|
return nil, errors.New("no acceptable authentication methods") |
|
} |
|
if d.Authenticate != nil { |
|
if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil { |
|
return |
|
} |
|
} |
|
|
|
b = b[:0] |
|
b = append(b, Version5, byte(d.cmd), 0) |
|
if ip := net.ParseIP(host); ip != nil { |
|
if ip4 := ip.To4(); ip4 != nil { |
|
b = append(b, AddrTypeIPv4) |
|
b = append(b, ip4...) |
|
} else if ip6 := ip.To16(); ip6 != nil { |
|
b = append(b, AddrTypeIPv6) |
|
b = append(b, ip6...) |
|
} else { |
|
return nil, errors.New("unknown address type") |
|
} |
|
} else { |
|
if len(host) > 255 { |
|
return nil, errors.New("FQDN too long") |
|
} |
|
b = append(b, AddrTypeFQDN) |
|
b = append(b, byte(len(host))) |
|
b = append(b, host...) |
|
} |
|
b = append(b, byte(port>>8), byte(port)) |
|
if _, ctxErr = c.Write(b); ctxErr != nil { |
|
return |
|
} |
|
|
|
if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil { |
|
return |
|
} |
|
if b[0] != Version5 { |
|
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) |
|
} |
|
if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded { |
|
return nil, errors.New("unknown error " + cmdErr.String()) |
|
} |
|
if b[2] != 0 { |
|
return nil, errors.New("non-zero reserved field") |
|
} |
|
l := 2 |
|
var a Addr |
|
switch b[3] { |
|
case AddrTypeIPv4: |
|
l += net.IPv4len |
|
a.IP = make(net.IP, net.IPv4len) |
|
case AddrTypeIPv6: |
|
l += net.IPv6len |
|
a.IP = make(net.IP, net.IPv6len) |
|
case AddrTypeFQDN: |
|
if _, err := io.ReadFull(c, b[:1]); err != nil { |
|
return nil, err |
|
} |
|
l += int(b[0]) |
|
default: |
|
return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3]))) |
|
} |
|
if cap(b) < l { |
|
b = make([]byte, l) |
|
} else { |
|
b = b[:l] |
|
} |
|
if _, ctxErr = io.ReadFull(c, b); ctxErr != nil { |
|
return |
|
} |
|
if a.IP != nil { |
|
copy(a.IP, b) |
|
} else { |
|
a.Name = string(b[:len(b)-2]) |
|
} |
|
a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1]) |
|
return &a, nil |
|
} |
|
|
|
// splitHostPort breaks address, in "host:port" form, into its host part and
// a numeric port. It returns an error if address cannot be split, if the port
// is not a decimal integer, or if the port lies outside 1..65535.
func splitHostPort(address string) (string, int, error) {
	h, p, splitErr := net.SplitHostPort(address)
	if splitErr != nil {
		return "", 0, splitErr
	}
	n, convErr := strconv.Atoi(p)
	switch {
	case convErr != nil:
		return "", 0, convErr
	case n < 1 || n > 0xffff:
		return "", 0, errors.New("port number out of range " + p)
	}
	return h, n, nil
}
|
|
|