You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
161 lines
3.8 KiB
161 lines
3.8 KiB
package pgx |
|
|
|
import ( |
|
"database/sql/driver" |
|
"fmt" |
|
"reflect" |
|
|
|
"github.com/jackc/pgtype" |
|
) |
|
|
|
type extendedQueryBuilder struct { |
|
paramValues [][]byte |
|
paramValueBytes []byte |
|
paramFormats []int16 |
|
resultFormats []int16 |
|
} |
|
|
|
func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error { |
|
f := chooseParameterFormatCode(ci, oid, arg) |
|
eqb.paramFormats = append(eqb.paramFormats, f) |
|
|
|
v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg) |
|
if err != nil { |
|
return err |
|
} |
|
eqb.paramValues = append(eqb.paramValues, v) |
|
|
|
return nil |
|
} |
|
|
|
func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) { |
|
eqb.resultFormats = append(eqb.resultFormats, f) |
|
} |
|
|
|
// Reset readies eqb to build another query. |
|
func (eqb *extendedQueryBuilder) Reset() { |
|
eqb.paramValues = eqb.paramValues[0:0] |
|
eqb.paramValueBytes = eqb.paramValueBytes[0:0] |
|
eqb.paramFormats = eqb.paramFormats[0:0] |
|
eqb.resultFormats = eqb.resultFormats[0:0] |
|
|
|
if cap(eqb.paramValues) > 64 { |
|
eqb.paramValues = make([][]byte, 0, 64) |
|
} |
|
|
|
if cap(eqb.paramValueBytes) > 256 { |
|
eqb.paramValueBytes = make([]byte, 0, 256) |
|
} |
|
|
|
if cap(eqb.paramFormats) > 64 { |
|
eqb.paramFormats = make([]int16, 0, 64) |
|
} |
|
if cap(eqb.resultFormats) > 64 { |
|
eqb.resultFormats = make([]int16, 0, 64) |
|
} |
|
} |
|
|
|
func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { |
|
if arg == nil { |
|
return nil, nil |
|
} |
|
|
|
refVal := reflect.ValueOf(arg) |
|
argIsPtr := refVal.Kind() == reflect.Ptr |
|
|
|
if argIsPtr && refVal.IsNil() { |
|
return nil, nil |
|
} |
|
|
|
if eqb.paramValueBytes == nil { |
|
eqb.paramValueBytes = make([]byte, 0, 128) |
|
} |
|
|
|
var err error |
|
var buf []byte |
|
pos := len(eqb.paramValueBytes) |
|
|
|
if arg, ok := arg.(string); ok { |
|
return []byte(arg), nil |
|
} |
|
|
|
if formatCode == TextFormatCode { |
|
if arg, ok := arg.(pgtype.TextEncoder); ok { |
|
buf, err = arg.EncodeText(ci, eqb.paramValueBytes) |
|
if err != nil { |
|
return nil, err |
|
} |
|
if buf == nil { |
|
return nil, nil |
|
} |
|
eqb.paramValueBytes = buf |
|
return eqb.paramValueBytes[pos:], nil |
|
} |
|
} else if formatCode == BinaryFormatCode { |
|
if arg, ok := arg.(pgtype.BinaryEncoder); ok { |
|
buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes) |
|
if err != nil { |
|
return nil, err |
|
} |
|
if buf == nil { |
|
return nil, nil |
|
} |
|
eqb.paramValueBytes = buf |
|
return eqb.paramValueBytes[pos:], nil |
|
} |
|
} |
|
|
|
if argIsPtr { |
|
// We have already checked that arg is not pointing to nil, |
|
// so it is safe to dereference here. |
|
arg = refVal.Elem().Interface() |
|
return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg) |
|
} |
|
|
|
if dt, ok := ci.DataTypeForOID(oid); ok { |
|
value := dt.Value |
|
err := value.Set(arg) |
|
if err != nil { |
|
{ |
|
if arg, ok := arg.(driver.Valuer); ok { |
|
v, err := callValuerValue(arg) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) |
|
} |
|
} |
|
|
|
return nil, err |
|
} |
|
|
|
return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) |
|
} |
|
|
|
// There is no data type registered for the destination OID, but maybe there is data type registered for the arg |
|
// type. If so use it's text encoder (if available). |
|
if dt, ok := ci.DataTypeForValue(arg); ok { |
|
value := dt.Value |
|
if textEncoder, ok := value.(pgtype.TextEncoder); ok { |
|
err := value.Set(arg) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes) |
|
if err != nil { |
|
return nil, err |
|
} |
|
if buf == nil { |
|
return nil, nil |
|
} |
|
eqb.paramValueBytes = buf |
|
return eqb.paramValueBytes[pos:], nil |
|
} |
|
} |
|
|
|
if strippedArg, ok := stripNamedType(&refVal); ok { |
|
return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg) |
|
} |
|
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) |
|
}
|
|
|