You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
201 lines
5.3 KiB
201 lines
5.3 KiB
package pgproto3 |
|
|
|
import ( |
|
"encoding/binary" |
|
"errors" |
|
"fmt" |
|
"io" |
|
) |
|
|
|
// Frontend acts as a client for the PostgreSQL wire protocol version 3. |
|
type Frontend struct { |
|
cr ChunkReader |
|
w io.Writer |
|
|
|
// Backend message flyweights |
|
authenticationOk AuthenticationOk |
|
authenticationCleartextPassword AuthenticationCleartextPassword |
|
authenticationMD5Password AuthenticationMD5Password |
|
authenticationSASL AuthenticationSASL |
|
authenticationSASLContinue AuthenticationSASLContinue |
|
authenticationSASLFinal AuthenticationSASLFinal |
|
backendKeyData BackendKeyData |
|
bindComplete BindComplete |
|
closeComplete CloseComplete |
|
commandComplete CommandComplete |
|
copyBothResponse CopyBothResponse |
|
copyData CopyData |
|
copyInResponse CopyInResponse |
|
copyOutResponse CopyOutResponse |
|
copyDone CopyDone |
|
dataRow DataRow |
|
emptyQueryResponse EmptyQueryResponse |
|
errorResponse ErrorResponse |
|
functionCallResponse FunctionCallResponse |
|
noData NoData |
|
noticeResponse NoticeResponse |
|
notificationResponse NotificationResponse |
|
parameterDescription ParameterDescription |
|
parameterStatus ParameterStatus |
|
parseComplete ParseComplete |
|
readyForQuery ReadyForQuery |
|
rowDescription RowDescription |
|
portalSuspended PortalSuspended |
|
|
|
bodyLen int |
|
msgType byte |
|
partialMsg bool |
|
authType uint32 |
|
} |
|
|
|
// NewFrontend creates a new Frontend. |
|
func NewFrontend(cr ChunkReader, w io.Writer) *Frontend { |
|
return &Frontend{cr: cr, w: w} |
|
} |
|
|
|
// Send sends a message to the backend. |
|
func (f *Frontend) Send(msg FrontendMessage) error { |
|
_, err := f.w.Write(msg.Encode(nil)) |
|
return err |
|
} |
|
|
|
// translateEOFtoErrUnexpectedEOF maps io.EOF to io.ErrUnexpectedEOF. Every
// other error, including nil, is returned unchanged. Receive applies it to
// both header and body reads, so running out of input while a message is
// expected is reported as unexpected.
func translateEOFtoErrUnexpectedEOF(err error) error {
	switch err {
	case io.EOF:
		return io.ErrUnexpectedEOF
	default:
		return err
	}
}
|
|
|
// Receive receives a message from the backend. The returned message is only valid until the next call to Receive. |
|
func (f *Frontend) Receive() (BackendMessage, error) { |
|
if !f.partialMsg { |
|
header, err := f.cr.Next(5) |
|
if err != nil { |
|
return nil, translateEOFtoErrUnexpectedEOF(err) |
|
} |
|
|
|
f.msgType = header[0] |
|
f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 |
|
f.partialMsg = true |
|
} |
|
|
|
msgBody, err := f.cr.Next(f.bodyLen) |
|
if err != nil { |
|
return nil, translateEOFtoErrUnexpectedEOF(err) |
|
} |
|
|
|
f.partialMsg = false |
|
|
|
var msg BackendMessage |
|
switch f.msgType { |
|
case '1': |
|
msg = &f.parseComplete |
|
case '2': |
|
msg = &f.bindComplete |
|
case '3': |
|
msg = &f.closeComplete |
|
case 'A': |
|
msg = &f.notificationResponse |
|
case 'c': |
|
msg = &f.copyDone |
|
case 'C': |
|
msg = &f.commandComplete |
|
case 'd': |
|
msg = &f.copyData |
|
case 'D': |
|
msg = &f.dataRow |
|
case 'E': |
|
msg = &f.errorResponse |
|
case 'G': |
|
msg = &f.copyInResponse |
|
case 'H': |
|
msg = &f.copyOutResponse |
|
case 'I': |
|
msg = &f.emptyQueryResponse |
|
case 'K': |
|
msg = &f.backendKeyData |
|
case 'n': |
|
msg = &f.noData |
|
case 'N': |
|
msg = &f.noticeResponse |
|
case 'R': |
|
var err error |
|
msg, err = f.findAuthenticationMessageType(msgBody) |
|
if err != nil { |
|
return nil, err |
|
} |
|
case 's': |
|
msg = &f.portalSuspended |
|
case 'S': |
|
msg = &f.parameterStatus |
|
case 't': |
|
msg = &f.parameterDescription |
|
case 'T': |
|
msg = &f.rowDescription |
|
case 'V': |
|
msg = &f.functionCallResponse |
|
case 'W': |
|
msg = &f.copyBothResponse |
|
case 'Z': |
|
msg = &f.readyForQuery |
|
default: |
|
return nil, fmt.Errorf("unknown message type: %c", f.msgType) |
|
} |
|
|
|
err = msg.Decode(msgBody) |
|
return msg, err |
|
} |
|
|
|
// Authentication message type constants. Each is a value of the 32-bit code
// at the start of an Authentication ('R') message body. See
// src/include/libpq/pqcomm.h in the PostgreSQL source for all constants.
const (
	// Selects the AuthenticationOk message.
	AuthTypeOk = 0

	// Password-based requests.
	AuthTypeCleartextPassword = 3
	AuthTypeMD5Password       = 5

	// Requests this package reports as unimplemented.
	AuthTypeSCMCreds = 6
	AuthTypeGSS      = 7
	AuthTypeGSSCont  = 8
	AuthTypeSSPI     = 9

	// SASL exchange steps.
	AuthTypeSASL         = 10
	AuthTypeSASLContinue = 11
	AuthTypeSASLFinal    = 12
)
|
|
|
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) { |
|
if len(src) < 4 { |
|
return nil, errors.New("authentication message too short") |
|
} |
|
f.authType = binary.BigEndian.Uint32(src[:4]) |
|
|
|
switch f.authType { |
|
case AuthTypeOk: |
|
return &f.authenticationOk, nil |
|
case AuthTypeCleartextPassword: |
|
return &f.authenticationCleartextPassword, nil |
|
case AuthTypeMD5Password: |
|
return &f.authenticationMD5Password, nil |
|
case AuthTypeSCMCreds: |
|
return nil, errors.New("AuthTypeSCMCreds is unimplemented") |
|
case AuthTypeGSS: |
|
return nil, errors.New("AuthTypeGSS is unimplemented") |
|
case AuthTypeGSSCont: |
|
return nil, errors.New("AuthTypeGSSCont is unimplemented") |
|
case AuthTypeSSPI: |
|
return nil, errors.New("AuthTypeSSPI is unimplemented") |
|
case AuthTypeSASL: |
|
return &f.authenticationSASL, nil |
|
case AuthTypeSASLContinue: |
|
return &f.authenticationSASLContinue, nil |
|
case AuthTypeSASLFinal: |
|
return &f.authenticationSASLFinal, nil |
|
default: |
|
return nil, fmt.Errorf("unknown authentication type: %d", f.authType) |
|
} |
|
} |
|
|
|
// GetAuthType returns the authType used in the current state of the frontend. |
|
// See SetAuthType for more information. |
|
func (f *Frontend) GetAuthType() uint32 { |
|
return f.authType |
|
}
|
|
|